# kopa/kopa.py
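"""Prefix-injection wrappers around a Llama causal language model.

Defines three modules:
  * KoPA: looks up knowledge-graph (KG) embeddings by ``embedding_ids`` and
    prepends them to the token embeddings before calling the wrapped
    LlamaForCausalLM, extending the attention mask and labels accordingly.
  * KoPAWithAdapter: blends a static prefix (token-id embeddings) with a
    dynamic prefix projected from 3-dimensional sensor readings by a small
    MLP, then feeds the prefixed embeddings to the wrapped model.
  * PretrainKGEmbedding: adapts frozen pretrained KG entity/relation
    embeddings into prefix embeddings of the LLM hidden size.
"""
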
import torch
import torch.nn as nn
from typing import Optional, List, Union, Tuple
from transformers import LlamaForCausalLM


class KoPA(nn.Module):
    def __init__(
        self,
        model: LlamaForCausalLM
    ) -> None:
        super(KoPA, self).__init__()
        self.llama_model = model
        # KG embedding table: up to 100 ids, embedded directly at the LLM hidden size (3072).
        self.embeddings = nn.Embedding(100, 3072)
        # self.embeddings = PrefixKGEmbedding(
        #     num_ent=2034,
        #     num_rel=42,
        #     dim_llm=3072,
        #     num_prefix=1
        # )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        embedding_ids: torch.LongTensor = None
    ):
        # Guard against out-of-range KG embedding ids.
        if embedding_ids.max() >= self.embeddings.num_embeddings or embedding_ids.min() < 0:
            print(f"[ERROR] embedding_ids out of range! max: {embedding_ids.max()}, min: {embedding_ids.min()}")
            embedding_ids = torch.clamp(embedding_ids, min=0, max=self.embeddings.num_embeddings - 1)
        # Look up KG prefix embeddings and prepend them to the token embeddings.
        kg_embeds = self.embeddings(embedding_ids)
        batch_size, seq_len, _ = kg_embeds.shape
        token_embeds = self.llama_model.model.model.embed_tokens(input_ids)
        input_embeds = torch.cat((kg_embeds, token_embeds), dim=1)
        # Extend the attention mask and labels to cover the prefix positions;
        # the prefix labels are -100 so those positions are ignored by the loss.
        prefix_mask = torch.ones((batch_size, seq_len), dtype=attention_mask.dtype, device=attention_mask.device)
        prefix_labels = torch.full((batch_size, seq_len), fill_value=-100, dtype=torch.long, device=labels.device)
        new_attention_mask = torch.cat((prefix_mask, attention_mask), dim=-1)
        new_labels = torch.cat((prefix_labels, labels), dim=-1)
        return self.llama_model(
            input_ids=None,
            attention_mask=new_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=input_embeds,
            labels=new_labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class KoPAWithAdapter(nn.Module):
    def __init__(self, model, num_prefix, tokenizer=None):
        super().__init__()
        self.model = model
        self.num_prefix = num_prefix
        hidden_size = model.config.hidden_size
        # Get vocab_size from the tokenizer; fall back to the Llama default.
        vocab_size = tokenizer.vocab_size if tokenizer else 32000
        self.static_prefix_embedding = nn.Embedding(vocab_size, hidden_size)
        self.embeddings = self.static_prefix_embedding  # keep this attribute for compatibility
        # Project 3-dimensional sensor readings up to the LLM hidden size.
        self.sensor_mlp = nn.Sequential(
            nn.Linear(3, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size)
        )
        # LayerNorm applied to the mixed prefix.
        self.norm = nn.LayerNorm(hidden_size)
        print(f"[INFO] Model initialized: hidden_size={hidden_size}, vocab_size={vocab_size}")

    def forward(self, input_ids, attention_mask, static_prefix=None, sensor_data=None, labels=None, **kwargs):
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        # Make sure all components live on the same device as the inputs.
        self.static_prefix_embedding = self.static_prefix_embedding.to(device)
        self.sensor_mlp = self.sensor_mlp.to(device)
        self.norm = self.norm.to(device)
        # Static prefix: embed the given prefix token ids, or fall back to zeros.
        if static_prefix is not None:
            static_prefix = static_prefix.to(device)
            static_prefix = self.static_prefix_embedding(static_prefix)
        else:
            static_prefix = torch.zeros(
                (batch_size, self.num_prefix, self.model.config.hidden_size),
                device=device
            )
        # Dynamic prefix: project the sensor readings, or fall back to zeros.
        if sensor_data is not None:
            sensor_data = sensor_data.to(device)
            if sensor_data.dim() == 1:
                sensor_data = sensor_data.unsqueeze(0)
            try:
                dynamic_prefix = self.sensor_mlp(sensor_data)
                dynamic_prefix = dynamic_prefix.unsqueeze(1).expand(-1, self.num_prefix, -1)
            except Exception as e:
                print(f"[ERROR] sensor_mlp failed: {e}")
                dynamic_prefix = torch.zeros_like(static_prefix)
        else:
            dynamic_prefix = torch.zeros_like(static_prefix)
        # Mix static and dynamic prefixes with a fixed weight, then normalize.
        alpha = 0.6
        final_prefix = alpha * static_prefix + (1 - alpha) * dynamic_prefix
        final_prefix = self.norm(final_prefix)
        # Token embeddings, with the prefix prepended along the sequence dimension.
        token_embeds = self.model.model.embed_tokens(input_ids)
        input_embeds = torch.cat((final_prefix, token_embeds), dim=1)
        # Extend the attention mask to cover the prefix positions.
        prefix_attention_mask = torch.ones(
            (batch_size, self.num_prefix),
            dtype=attention_mask.dtype,
            device=device
        )
        extended_attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        # Key fix: extend the labels as well.
        if labels is not None:
            # Label the prefix positions with -100 so they are ignored by the loss.
            prefix_labels = torch.full(
                (batch_size, self.num_prefix),
                fill_value=-100,
                dtype=labels.dtype,
                device=device
            )
            extended_labels = torch.cat((prefix_labels, labels), dim=1)
        else:
            extended_labels = None
        # Debug output
        # print(f"[DEBUG] original input size: {input_ids.shape}")
        # print(f"[DEBUG] extended embeddings size: {input_embeds.shape}")
        # print(f"[DEBUG] extended mask size: {extended_attention_mask.shape}")
        # if extended_labels is not None:
        #     print(f"[DEBUG] extended labels size: {extended_labels.shape}")
        # Make sure input_ids is not passed along with inputs_embeds.
        if 'input_ids' in kwargs:
            del kwargs['input_ids']
        # Pass the extended labels (with -100 prefix positions) to the wrapped model.
        return self.model(
            inputs_embeds=input_embeds,
            attention_mask=extended_attention_mask,
            labels=extended_labels,
            use_cache=False,
            **kwargs)
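

# Illustrative sketch (not called anywhere in this module): a minimal recreation of
# the prefix-blending step in KoPAWithAdapter.forward with small, assumed toy
# dimensions (hidden_size=16, num_prefix=4, a vocabulary of 32 prefix ids),
# useful for a quick shape sanity check.
def _prefix_mixing_sketch(hidden_size: int = 16, num_prefix: int = 4, batch_size: int = 2):
    static_prefix_embedding = nn.Embedding(32, hidden_size)
    sensor_mlp = nn.Sequential(
        nn.Linear(3, hidden_size // 2),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(hidden_size // 2, hidden_size),
    )
    norm = nn.LayerNorm(hidden_size)
    # Static prefix from ids, dynamic prefix from a 3-dimensional sensor reading.
    static_ids = torch.randint(0, 32, (batch_size, num_prefix))
    static_prefix = static_prefix_embedding(static_ids)            # (B, P, H)
    sensor_data = torch.randn(batch_size, 3)
    dynamic_prefix = sensor_mlp(sensor_data).unsqueeze(1).expand(-1, num_prefix, -1)
    # Same fixed 0.6 / 0.4 blend and LayerNorm as KoPAWithAdapter.forward.
    final_prefix = norm(0.6 * static_prefix + 0.4 * dynamic_prefix)
    return final_prefix.shape                                      # torch.Size([2, 4, 16])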


# class PrefixKGEmbedding(nn.Module):
#     def __init__(
#         self,
#         num_ent,
#         num_rel,
#         dim_llm,
#         num_prefix
#     ):
#         super(PrefixKGEmbedding, self).__init__()
#         self.emb_dim = num_prefix * dim_llm
#         self.ent_embeddings = nn.Embedding(num_ent, self.emb_dim)
#         self.rel_embeddings = nn.Embedding(num_rel, self.emb_dim)
#
#
#     def forward(self, triple_ids):
#         head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
#         h = self.ent_embeddings(head)
#         r = self.rel_embeddings(relation)
#         t = self.ent_embeddings(tail)
#         prefix = torch.stack((h, r, t), dim=1)
#         return prefix


class PretrainKGEmbedding(nn.Module):
    def __init__(
        self,
        pretrain_ent_embs,
        pretrain_rel_embs,
        dim_llm,
        num_prefix
    ):
        super(PretrainKGEmbedding, self).__init__()
        self.num_prefix = num_prefix
        self.llm_dim = dim_llm
        self.emb_dim = num_prefix * dim_llm
        self.ent_embeddings = nn.Embedding.from_pretrained(pretrain_ent_embs)
        self.rel_embeddings = nn.Embedding.from_pretrained(pretrain_rel_embs)
        self.pretrain_dim = self.ent_embeddings.weight.shape[1]
        # Freeze the pretrained embeddings; only the adapter is trained.
        self.ent_embeddings.requires_grad_(False)
        self.rel_embeddings.requires_grad_(False)
        self.adapter = nn.Linear(self.pretrain_dim, self.emb_dim)

    def forward(self, triple_ids):
        # Main training stage: (head, relation, tail) triples.
        if triple_ids.shape[1] == 3:
            head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
            h = self.ent_embeddings(head)
            r = self.rel_embeddings(relation)
            t = self.ent_embeddings(tail)
            pretrain_embs = torch.stack((h, r, t), dim=1)
            prefix = self.adapter(pretrain_embs).reshape(-1, 3 * self.num_prefix, self.llm_dim)
            return prefix
        # Entity-aware pre-tuning: entity ids only.
        else:
            ent = triple_ids.reshape(-1,)
            emb = self.ent_embeddings(ent)
            prefix = self.adapter(emb).reshape(-1, self.num_prefix, self.llm_dim)
            # print(prefix.shape)
            return prefix
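

# A minimal smoke test, assuming randomly initialized "pretrained" embeddings
# (a real run would load actual KG embeddings instead). It exercises both
# branches of PretrainKGEmbedding.forward and prints the resulting prefix shapes.
if __name__ == "__main__":
    num_ent, num_rel, pretrain_dim, dim_llm, num_prefix = 50, 10, 8, 32, 1
    kge = PretrainKGEmbedding(
        pretrain_ent_embs=torch.randn(num_ent, pretrain_dim),
        pretrain_rel_embs=torch.randn(num_rel, pretrain_dim),
        dim_llm=dim_llm,
        num_prefix=num_prefix,
    )
    triples = torch.tensor([[0, 1, 2], [3, 4, 5]])   # (batch, 3) head/relation/tail ids
    entities = torch.tensor([[0], [1], [2], [3]])    # (batch, 1) entity ids for pre-tuning
    print(kge(triples).shape)                        # torch.Size([2, 3, 32])
    print(kge(entities).shape)                       # torch.Size([4, 1, 32])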