# kopa/.ipynb_checkpoints/kopa-checkpoint.py

import torch
import torch.nn as nn
from typing import Optional, List, Union, Tuple
from transformers import LlamaForCausalLM
class KoPA(nn.Module):
    def __init__(
        self,
        model,
        num_prefix: int = 1
    ) -> None:
        super(KoPA, self).__init__()
        self.llama_model = model
        self.num_prefix = num_prefix
        hidden_size = model.config.hidden_size
        # Knowledge-graph prefix embeddings (the table size of 100 entries is kept from the original code)
        self.embeddings = nn.Embedding(100, hidden_size)
        # Freeze the base model; only the prefix embeddings receive gradients
        for param in self.llama_model.parameters():
            param.requires_grad = False
        self.embeddings.requires_grad_(True)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        embedding_ids: torch.LongTensor = None
    ):
        # Guard against out-of-range knowledge-graph embedding ids
        if embedding_ids.max() >= self.embeddings.num_embeddings or embedding_ids.min() < 0:
            print(f"[ERROR] embedding_ids out of range! max: {embedding_ids.max()}, min: {embedding_ids.min()}")
            embedding_ids = torch.clamp(embedding_ids, min=0, max=self.embeddings.num_embeddings - 1)
        kg_embeds = self.embeddings(embedding_ids)
        batch_size, seq_len, _ = kg_embeds.shape
        # The token embedding layer lives at different paths depending on the architecture
        if hasattr(self.llama_model, 'transformer'):
            # Qwen-style models
            token_embeds = self.llama_model.transformer.wte(input_ids)
        elif hasattr(self.llama_model, 'model') and hasattr(self.llama_model.model, 'embed_tokens'):
            # LLaMA-style models
            token_embeds = self.llama_model.model.embed_tokens(input_ids)
        else:
            print("Could not locate the model's token embedding layer.")
            raise ValueError("Incompatible model structure")
        # Prepend the knowledge-graph prefix and extend the attention mask and labels;
        # prefix positions are excluded from the loss via -100 labels
        input_embeds = torch.cat((kg_embeds, token_embeds), dim=1)
        device = kg_embeds.device
        prefix_mask = torch.ones((batch_size, seq_len), dtype=attention_mask.dtype, device=device)
        prefix_labels = torch.full((batch_size, seq_len), fill_value=-100, dtype=torch.long, device=device)
        new_attention_mask = torch.cat((prefix_mask, attention_mask), dim=-1)
        new_labels = torch.cat((prefix_labels, labels), dim=-1)
        return self.llama_model(
            input_ids=None,
            attention_mask=new_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=input_embeds,
            labels=new_labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
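
# Hypothetical usage sketch for KoPA (the checkpoint name and shapes below are
# illustrative assumptions, not part of the original code): the wrapper expects
# per-position knowledge-graph ids in `embedding_ids`, whose embeddings are
# prepended to the prompt.
#
#   base = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint
#   kopa = KoPA(base)
#   out = kopa(
#       input_ids=input_ids,            # (batch, text_len)
#       attention_mask=attention_mask,  # (batch, text_len)
#       labels=labels,                  # (batch, text_len)
#       embedding_ids=kg_ids,           # (batch, prefix_len) ids into kopa.embeddings
#   )
#   loss = out.loss
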
class KoPAWithAdapter(nn.Module):
    def __init__(self, model, num_prefix, tokenizer=None):
        super().__init__()
        self.model = model
        self.num_prefix = num_prefix
        hidden_size = model.config.hidden_size
        # Print the wrapped model type for easier debugging
        print(f"[INFO] Initializing KoPAWithAdapter with model type: {type(model).__name__}")
        # Use the tokenizer to determine the vocabulary size when available
        vocab_size = tokenizer.vocab_size if tokenizer else 151936  # default vocab size of Qwen2.5
        print(f"[INFO] Using vocab size: {vocab_size}")
        self.static_prefix_embedding = nn.Embedding(vocab_size, hidden_size)
        self.embeddings = self.static_prefix_embedding  # kept as an alias for compatibility
        self.sensor_mlp = nn.Sequential(
            nn.Linear(3, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size)
        )
        # LayerNorm applied to the mixed prefix
        self.norm = nn.LayerNorm(hidden_size)
        print(f"[INFO] Model initialized: hidden_size={hidden_size}, vocab_size={vocab_size}")
        # Detect where the base model keeps its token embedding layer
        self._detect_embedding_path()

    def _detect_embedding_path(self):
        """Detect the path to the base model's token embedding layer."""
        self.embedding_path = None
        # Try the common layouts
        if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
            self.embedding_path = "transformer.wte"
        elif hasattr(self.model, 'model') and hasattr(self.model.model, 'embed_tokens'):
            self.embedding_path = "model.embed_tokens"
        elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(self.model.model.model, 'embed_tokens'):
            self.embedding_path = "model.model.model.embed_tokens"
        if self.embedding_path:
            print(f"[INFO] Detected embedding layer path: {self.embedding_path}")
        else:
            print("[WARNING] Could not auto-detect the embedding layer path; multiple paths will be tried in forward()")
    def forward(self, input_ids, attention_mask, static_prefix=None, sensor_data=None, labels=None, **kwargs):
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        # Make sure all adapter components live on the same device as the inputs
        self.static_prefix_embedding = self.static_prefix_embedding.to(device)
        self.sensor_mlp = self.sensor_mlp.to(device)
        self.norm = self.norm.to(device)
        # Static prefix: embed the given prefix token ids, or fall back to zeros
        if static_prefix is not None:
            static_prefix = static_prefix.to(device)
            static_prefix = self.static_prefix_embedding(static_prefix)
        else:
            static_prefix = torch.zeros(
                (batch_size, self.num_prefix, self.model.config.hidden_size),
                device=device
            )
        # Dynamic prefix: project the sensor readings into the hidden space
        if sensor_data is not None:
            sensor_data = sensor_data.to(device)
            if sensor_data.dim() == 1:
                sensor_data = sensor_data.unsqueeze(0)
            try:
                dynamic_prefix = self.sensor_mlp(sensor_data)
                dynamic_prefix = dynamic_prefix.unsqueeze(1).expand(-1, self.num_prefix, -1)
            except Exception as e:
                print(f"[ERROR] sensor_mlp failed: {e}")
                dynamic_prefix = torch.zeros_like(static_prefix)
        else:
            dynamic_prefix = torch.zeros_like(static_prefix)
        # Mix the two prefixes and normalize
        alpha = 0.6
        final_prefix = alpha * static_prefix + (1 - alpha) * dynamic_prefix
        final_prefix = self.norm(final_prefix)
        # Token embeddings: use the detected embedding path if available
        try:
            if self.embedding_path == "transformer.wte":
                token_embeds = self.model.transformer.wte(input_ids)
            elif self.embedding_path == "model.embed_tokens":
                token_embeds = self.model.model.embed_tokens(input_ids)
            elif self.embedding_path == "model.model.model.embed_tokens":
                token_embeds = self.model.model.model.embed_tokens(input_ids)
            else:
                # Fall back to probing the common layouts
                if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
                    token_embeds = self.model.transformer.wte(input_ids)
                    self.embedding_path = "transformer.wte"
                elif hasattr(self.model, 'model') and hasattr(self.model.model, 'embed_tokens'):
                    token_embeds = self.model.model.embed_tokens(input_ids)
                    self.embedding_path = "model.embed_tokens"
                elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(self.model.model.model, 'embed_tokens'):
                    token_embeds = self.model.model.model.embed_tokens(input_ids)
                    self.embedding_path = "model.model.model.embed_tokens"
                else:
                    raise ValueError("Could not locate the token embedding layer")
                print(f"[INFO] Found embedding layer path: {self.embedding_path}")
        except Exception as e:
            print(f"[ERROR] Failed to get token embeddings: {e}")
            # Dump candidate module names to help debugging
            print("Model structure:")
            for name, _ in self.model.named_modules():
                if 'embed' in name or 'wte' in name:
                    print(f"  - {name}")
            raise
        input_embeds = torch.cat((final_prefix, token_embeds), dim=1)
        # Extend the attention mask to cover the prefix positions
        prefix_attention_mask = torch.ones(
            (batch_size, self.num_prefix),
            dtype=attention_mask.dtype,
            device=device
        )
        extended_attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        # Extend the labels: prefix positions get -100 so the loss ignores them
        if labels is not None:
            prefix_labels = torch.full(
                (batch_size, self.num_prefix),
                fill_value=-100,
                dtype=labels.dtype,
                device=device
            )
            extended_labels = torch.cat((prefix_labels, labels), dim=1)
        else:
            extended_labels = None
        # Make sure input_ids is not passed alongside inputs_embeds
        if 'input_ids' in kwargs:
            del kwargs['input_ids']
        return self.model(
            inputs_embeds=input_embeds,
            attention_mask=extended_attention_mask,
            labels=extended_labels,
            use_cache=False,
            **kwargs)
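
# Hypothetical usage sketch for KoPAWithAdapter (the checkpoint name and shapes
# below are illustrative assumptions, not part of the original code):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B")
#   tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")
#   model = KoPAWithAdapter(base, num_prefix=4, tokenizer=tok)
#   out = model(
#       input_ids,                  # (batch, text_len)
#       attention_mask,             # (batch, text_len)
#       static_prefix=prefix_ids,   # (batch, num_prefix) token ids for the static prefix
#       sensor_data=sensor_values,  # (batch, 3) float features for the dynamic prefix
#       labels=labels,              # (batch, text_len)
#   )
#   loss = out.loss
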
# class PrefixKGEmbedding(nn.Module):
#     def __init__(
#         self,
#         num_ent,
#         num_rel,
#         dim_llm,
#         num_prefix
#     ):
#         super(PrefixKGEmbedding, self).__init__()
#         self.emb_dim = num_prefix * dim_llm
#         self.ent_embeddings = nn.Embedding(num_ent, self.emb_dim)
#         self.rel_embeddings = nn.Embedding(num_rel, self.emb_dim)
#
#     def forward(self, triple_ids):
#         head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
#         h = self.ent_embeddings(head)
#         r = self.rel_embeddings(relation)
#         t = self.ent_embeddings(tail)
#         prefix = torch.stack((h, r, t), dim=1)
#         return prefix
class PretrainKGEmbedding(nn.Module):
    def __init__(
        self,
        pretrain_ent_embs,
        pretrain_rel_embs,
        dim_llm,
        num_prefix
    ):
        super(PretrainKGEmbedding, self).__init__()
        self.num_prefix = num_prefix
        self.llm_dim = dim_llm
        self.emb_dim = num_prefix * dim_llm
        self.ent_embeddings = nn.Embedding.from_pretrained(pretrain_ent_embs)
        self.rel_embeddings = nn.Embedding.from_pretrained(pretrain_rel_embs)
        self.pretrain_dim = self.ent_embeddings.weight.shape[1]
        # Freeze the pretrained embeddings; only the adapter is trained
        self.ent_embeddings.requires_grad_(False)
        self.rel_embeddings.requires_grad_(False)
        self.adapter = nn.Linear(self.pretrain_dim, self.emb_dim)

    def forward(self, triple_ids):
        # Main training stage: (head, relation, tail) triples
        if triple_ids.shape[1] == 3:
            head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
            h = self.ent_embeddings(head)
            r = self.rel_embeddings(relation)
            t = self.ent_embeddings(tail)
            pretrain_embs = torch.stack((h, r, t), dim=1)
            prefix = self.adapter(pretrain_embs).reshape(-1, 3 * self.num_prefix, self.llm_dim)
            return prefix
        # Entity-aware pre-tuning: a batch of entity ids only
        else:
            ent = triple_ids.reshape(-1,)
            emb = self.ent_embeddings(ent)
            prefix = self.adapter(emb).reshape(-1, self.num_prefix, self.llm_dim)
            # print(prefix.shape)
            return prefix
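
# A minimal, hypothetical smoke test for PretrainKGEmbedding. The embedding
# table sizes, pretrained dimension, and dim_llm below are illustrative
# assumptions, not values from the original training setup.
if __name__ == "__main__":
    pretrain_ent_embs = torch.randn(10, 512)  # 10 entities, 512-d pretrained embeddings
    pretrain_rel_embs = torch.randn(5, 512)   # 5 relations
    kg_embedding = PretrainKGEmbedding(pretrain_ent_embs, pretrain_rel_embs, dim_llm=4096, num_prefix=1)
    triples = torch.tensor([[0, 1, 2], [3, 0, 4]])  # (batch=2, 3) head/relation/tail ids
    print(kg_embedding(triples).shape)              # expected: torch.Size([2, 3, 4096])
    entities = torch.tensor([[0], [1], [2]])        # entity-only input for pre-tuning
    print(kg_embedding(entities).shape)             # expected: torch.Size([3, 1, 4096])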