import torch
import torch.nn as nn
from typing import Optional, List, Union, Tuple

from transformers import LlamaForCausalLM


class KoPA(nn.Module):
    def __init__(
        self,
        model,
        num_prefix: int = 1
    ) -> None:
        super(KoPA, self).__init__()
        self.llama_model = model
        self.num_prefix = num_prefix
        hidden_size = model.config.hidden_size
        # KG embedding table used as a soft prefix; its width must match the LLM hidden size
        self.embeddings = nn.Embedding(100, hidden_size)

        # Freeze the base model; only keep gradients for the adapter parts
        for param in self.llama_model.parameters():
            param.requires_grad = False

        # Only enable gradients for the adapter components
        self.embeddings.requires_grad_(True)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        embedding_ids: torch.LongTensor = None
    ):
        if embedding_ids.max() >= self.embeddings.num_embeddings or embedding_ids.min() < 0:
            print(f"[ERROR] embedding_ids out of range! max: {embedding_ids.max()}, min: {embedding_ids.min()}")
            embedding_ids = torch.clamp(embedding_ids, min=0, max=self.embeddings.num_embeddings - 1)
        kg_embeds = self.embeddings(embedding_ids)
        batch_size, seq_len, _ = kg_embeds.shape
        if hasattr(self.llama_model, 'transformer'):
            # Qwen-style models
            token_embeds = self.llama_model.transformer.wte(input_ids)
        elif hasattr(self.llama_model, 'model') and hasattr(self.llama_model.model, 'embed_tokens'):
            # Original (Llama-style) path
            token_embeds = self.llama_model.model.embed_tokens(input_ids)
        else:
            # Debugging aid
            print("Could not locate the model's embedding layer; inspect the model structure...")
            raise ValueError("Incompatible model structure")
        input_embeds = torch.cat((kg_embeds, token_embeds), dim=1)

        # Extend the attention mask and labels to cover the KG prefix positions
        device = kg_embeds.device
        prefix_mask = torch.ones((batch_size, seq_len), dtype=attention_mask.dtype, device=device)
        prefix_labels = torch.full((batch_size, seq_len), fill_value=-100, dtype=torch.long, device=device)
        new_attention_mask = torch.cat((prefix_mask, attention_mask), dim=-1)
        new_labels = torch.cat((prefix_labels, labels), dim=-1)

        return self.llama_model(
            input_ids=None,
            attention_mask=new_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=input_embeds,
            labels=new_labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


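# --- Usage sketch ------------------------------------------------------------
# A minimal, hedged example of how KoPA might be wired up: wrap a frozen causal
# LM and prepend embeddings selected by `embedding_ids` to the token embeddings.
# The checkpoint name, prefix indices, and tensor shapes below are placeholder
# assumptions for illustration, not values prescribed by this module.
def _example_kopa_usage():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModelForCausalLM.from_pretrained(model_name)

    kopa = KoPA(base_model, num_prefix=1)

    enc = tokenizer("KoPA prepends knowledge-graph prefixes.", return_tensors="pt")
    labels = enc["input_ids"].clone()
    embedding_ids = torch.tensor([[0, 1, 2]])  # indices into the 100-entry KG table

    out = kopa(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        labels=labels,
        embedding_ids=embedding_ids,
        return_dict=True,
    )
    print(out.loss)

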
class KoPAWithAdapter(nn.Module):
    def __init__(self, model, num_prefix, tokenizer=None):
        super().__init__()
        self.model = model
        self.num_prefix = num_prefix
        hidden_size = model.config.hidden_size

        # Print model info for debugging
        print(f"[INFO] Initializing KoPAWithAdapter, model type: {type(model).__name__}")

        # Use the tokenizer to determine vocab_size
        vocab_size = tokenizer.vocab_size if tokenizer else 151936  # default vocab size of Qwen2.5
        print(f"[INFO] Using vocab size: {vocab_size}")

        self.static_prefix_embedding = nn.Embedding(vocab_size, hidden_size)
        self.embeddings = self.static_prefix_embedding  # keep this attribute for compatibility

        self.sensor_mlp = nn.Sequential(
            nn.Linear(3, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size)
        )

        # LayerNorm over the mixed prefix
        self.norm = nn.LayerNorm(hidden_size)

        print(f"[INFO] Model initialized: hidden_size={hidden_size}, vocab_size={vocab_size}")

        # Detect the embedding-layer path of the wrapped model
        self._detect_embedding_path()

    def _detect_embedding_path(self):
        """Detect the path to the wrapped model's token embedding layer."""
        self.embedding_path = None

        # Try several common paths
        if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
            self.embedding_path = "transformer.wte"
        elif hasattr(self.model, 'model') and hasattr(self.model.model, 'embed_tokens'):
            self.embedding_path = "model.embed_tokens"
        elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(self.model.model.model, 'embed_tokens'):
            self.embedding_path = "model.model.model.embed_tokens"

        if self.embedding_path:
            print(f"[INFO] Detected embedding layer path: {self.embedding_path}")
        else:
            print("[WARNING] Could not auto-detect the embedding layer path; several paths will be tried in forward()")

    def forward(self, input_ids, attention_mask, static_prefix=None, sensor_data=None, labels=None, **kwargs):
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Make sure all components are on the same device
        self.static_prefix_embedding = self.static_prefix_embedding.to(device)
        self.sensor_mlp = self.sensor_mlp.to(device)
        self.norm = self.norm.to(device)

        # Static prefix
        if static_prefix is not None:
            static_prefix = static_prefix.to(device)
            static_prefix = self.static_prefix_embedding(static_prefix)
        else:
            static_prefix = torch.zeros(
                (batch_size, self.num_prefix, self.model.config.hidden_size),
                device=device
            )

        # Dynamic prefix
        if sensor_data is not None:
            sensor_data = sensor_data.to(device)

            if sensor_data.dim() == 1:
                sensor_data = sensor_data.unsqueeze(0)

            try:
                dynamic_prefix = self.sensor_mlp(sensor_data)
                dynamic_prefix = dynamic_prefix.unsqueeze(1).expand(-1, self.num_prefix, -1)
            except Exception as e:
                print(f"[ERROR] sensor_mlp failed: {e}")
                dynamic_prefix = torch.zeros_like(static_prefix)
        else:
            dynamic_prefix = torch.zeros_like(static_prefix)

        # Mix the prefixes
        alpha = 0.6
        final_prefix = alpha * static_prefix + (1 - alpha) * dynamic_prefix
        final_prefix = self.norm(final_prefix)

        # Token embeddings - look them up via the detected path
        try:
            if self.embedding_path == "transformer.wte":
                token_embeds = self.model.transformer.wte(input_ids)
            elif self.embedding_path == "model.embed_tokens":
                token_embeds = self.model.model.embed_tokens(input_ids)
            elif self.embedding_path == "model.model.model.embed_tokens":
                token_embeds = self.model.model.model.embed_tokens(input_ids)
            else:
                # Try several possible paths
                if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
                    token_embeds = self.model.transformer.wte(input_ids)
                    self.embedding_path = "transformer.wte"
                elif hasattr(self.model, 'model') and hasattr(self.model.model, 'embed_tokens'):
                    token_embeds = self.model.model.embed_tokens(input_ids)
                    self.embedding_path = "model.embed_tokens"
                elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(self.model.model.model, 'embed_tokens'):
                    token_embeds = self.model.model.model.embed_tokens(input_ids)
                    self.embedding_path = "model.model.model.embed_tokens"
                else:
                    raise ValueError("Could not find the embedding layer path")
                print(f"[INFO] Successfully found embedding layer path: {self.embedding_path}")
        except Exception as e:
            print(f"[ERROR] Failed to get token embeddings: {e}")
            # Print the model structure to help debugging
            print("Model structure:")
            for name, _ in self.model.named_modules():
                if 'embed' in name or 'wte' in name:
                    print(f"  - {name}")
            raise

        input_embeds = torch.cat((final_prefix, token_embeds), dim=1)

        # Extend the attention mask
        prefix_attention_mask = torch.ones(
            (batch_size, self.num_prefix),
            dtype=attention_mask.dtype,
            device=device
        )
        extended_attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        # Labels
        if labels is not None:
            # Label the prefix positions with -100 (ignored by the loss)
            prefix_labels = torch.full(
                (batch_size, self.num_prefix),
                fill_value=-100,  # -100 masks these positions out of the loss
                dtype=labels.dtype,
                device=device
            )
            # Extend the labels
            extended_labels = torch.cat((prefix_labels, labels), dim=1)
        else:
            extended_labels = None

        # Make sure input_ids is not passed along
        if 'input_ids' in kwargs:
            del kwargs['input_ids']

        # Pass the extended labels through
        return self.model(
            inputs_embeds=input_embeds,
            attention_mask=extended_attention_mask,
            labels=extended_labels,
            use_cache=False,
            **kwargs)


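# --- Usage sketch ------------------------------------------------------------
# A minimal, hedged example of driving KoPAWithAdapter with both a static
# token-ID prefix and a 3-dimensional sensor reading. The checkpoint name,
# prefix length, prefix tokens, and sensor values are illustrative assumptions.
def _example_kopa_with_adapter_usage():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Qwen/Qwen2.5-0.5B"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModelForCausalLM.from_pretrained(model_name)

    num_prefix = 4
    wrapper = KoPAWithAdapter(base_model, num_prefix=num_prefix, tokenizer=tokenizer)

    enc = tokenizer("The adapter mixes static and sensor-driven prefixes.", return_tensors="pt")
    labels = enc["input_ids"].clone()

    # static_prefix holds num_prefix token IDs per example; sensor_data holds one
    # 3-dimensional reading per example (matching the sensor_mlp input width).
    static_prefix = torch.randint(0, tokenizer.vocab_size, (1, num_prefix))
    sensor_data = torch.tensor([[0.1, 0.2, 0.3]])

    out = wrapper(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        static_prefix=static_prefix,
        sensor_data=sensor_data,
        labels=labels,
    )
    print(out.loss)

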
# class PrefixKGEmbedding(nn.Module):
#     def __init__(
#         self,
#         num_ent,
#         num_rel,
#         dim_llm,
#         num_prefix
#     ):
#         super(PrefixKGEmbedding, self).__init__()
#         self.emb_dim = num_prefix * dim_llm
#         self.ent_embeddings = nn.Embedding(num_ent, self.emb_dim)
#         self.rel_embeddings = nn.Embedding(num_rel, self.emb_dim)
#
#
#     def forward(self, triple_ids):
#         head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
#         h = self.ent_embeddings(head)
#         r = self.rel_embeddings(relation)
#         t = self.ent_embeddings(tail)
#         prefix = torch.stack((h, r, t), dim=1)
#         return prefix


class PretrainKGEmbedding(nn.Module):
    def __init__(
        self,
        pretrain_ent_embs,
        pretrain_rel_embs,
        dim_llm,
        num_prefix
    ):
        super(PretrainKGEmbedding, self).__init__()
        self.num_prefix = num_prefix
        self.llm_dim = dim_llm
        self.emb_dim = num_prefix * dim_llm
        self.ent_embeddings = nn.Embedding.from_pretrained(pretrain_ent_embs)
        self.rel_embeddings = nn.Embedding.from_pretrained(pretrain_rel_embs)
        self.pretrain_dim = self.ent_embeddings.weight.shape[1]
        # Freeze the pretrained embeddings
        self.ent_embeddings.requires_grad_(False)
        self.rel_embeddings.requires_grad_(False)
        self.adapter = nn.Linear(self.pretrain_dim, self.emb_dim)

    def forward(self, triple_ids):
        # main training stage
        if triple_ids.shape[1] == 3:
            head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
            h = self.ent_embeddings(head)
            r = self.rel_embeddings(relation)
            t = self.ent_embeddings(tail)
            pretrain_embs = torch.stack((h, r, t), dim=1)
            prefix = self.adapter(pretrain_embs).reshape(-1, 3 * self.num_prefix, self.llm_dim)
            return prefix
        # entity-aware pre-tuning
        else:
            ent = triple_ids.reshape(-1,)
            emb = self.ent_embeddings(ent)
            prefix = self.adapter(emb).reshape(-1, self.num_prefix, self.llm_dim)
            # print(prefix.shape)
            return prefix
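

# --- Usage sketch ------------------------------------------------------------
# A minimal, hedged example of PretrainKGEmbedding with randomly initialised
# "pretrained" entity/relation tables. The sizes are illustrative assumptions;
# in practice the tensors would come from a trained KG-embedding model.
def _example_pretrain_kg_embedding_usage():
    num_ent, num_rel, pretrain_dim = 50, 10, 128
    dim_llm, num_prefix = 4096, 1

    pretrain_ent_embs = torch.randn(num_ent, pretrain_dim)
    pretrain_rel_embs = torch.randn(num_rel, pretrain_dim)

    kg_prefix = PretrainKGEmbedding(pretrain_ent_embs, pretrain_rel_embs, dim_llm, num_prefix)

    # Main stage: (batch, 3) head/relation/tail IDs -> (batch, 3 * num_prefix, dim_llm)
    triples = torch.tensor([[0, 1, 2], [3, 4, 5]])
    print(kg_prefix(triples).shape)  # torch.Size([2, 3, 4096])

    # Entity-aware pre-tuning: entity IDs only -> (batch, num_prefix, dim_llm)
    entities = torch.tensor([[7], [8]])
    print(kg_prefix(entities).shape)  # torch.Size([2, 1, 4096])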