# kopa/kopa.py
import torch
import torch.nn as nn
from typing import Optional, List, Union, Tuple
from transformers import LlamaForCausalLM


class KoPA(nn.Module):
    def __init__(
        self,
        model: LlamaForCausalLM
    ) -> None:
        super(KoPA, self).__init__()
        self.llama_model = model
        self.embeddings = nn.Embedding(100, 3072)
        # self.embeddings = PrefixKGEmbedding(
        #     num_ent=2034,
        #     num_rel=42,
        #     dim_llm=3072,
        #     num_prefix=1
        # )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        embedding_ids: torch.LongTensor = None
    ):
        # Guard against out-of-range ids so the embedding lookup cannot fail.
        if embedding_ids.max() >= self.embeddings.num_embeddings or embedding_ids.min() < 0:
            print(f"[ERROR] embedding_ids out of range! max: {embedding_ids.max()}, min: {embedding_ids.min()}")
            embedding_ids = torch.clamp(embedding_ids, min=0, max=self.embeddings.num_embeddings - 1)
        kg_embeds = self.embeddings(embedding_ids)
        batch_size, seq_len, _ = kg_embeds.shape
        token_embeds = self.llama_model.model.model.embed_tokens(input_ids)
        input_embeds = torch.cat((kg_embeds, token_embeds), dim=1)
        # Prepend an all-ones attention mask and ignore-index (-100) labels for the KG prefix,
        # created on the same device as the prefix embeddings rather than hard-coded to .cuda().
        device = kg_embeds.device
        prefix_mask = torch.ones((batch_size, seq_len), dtype=attention_mask.dtype, device=device)
        prefix_labels = torch.full((batch_size, seq_len), fill_value=-100, dtype=torch.long, device=device)
        new_attention_mask = torch.cat((prefix_mask, attention_mask), dim=-1)
        new_labels = torch.cat((prefix_labels, labels), dim=-1)
        return self.llama_model(
            input_ids=None,
            attention_mask=new_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=input_embeds,
            labels=new_labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
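

# Usage sketch for KoPA (commented out; illustrative only). The tensor shapes
# are assumptions, and the wrapped model is expected to expose
# .model.model.embed_tokens (e.g. a PEFT/LoRA-wrapped LLaMA) with 3072-dim
# hidden states, as forward() above requires.
#
# base = ...  # a LoRA-wrapped LlamaForCausalLM
# kopa = KoPA(base)
# input_ids = torch.randint(0, 32000, (1, 16))
# out = kopa(
#     input_ids=input_ids,
#     attention_mask=torch.ones_like(input_ids),
#     labels=input_ids.clone(),
#     embedding_ids=torch.tensor([[0, 1, 2]]),  # prepends 3 KG prefix embeddings
# )
# print(out.loss)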


class KoPAWithAdapter(nn.Module):
    def __init__(self, model, num_prefix, tokenizer=None):
        super().__init__()
        self.model = model
        self.num_prefix = num_prefix
        hidden_size = model.config.hidden_size
        # Use the tokenizer to get the vocab size when available
        vocab_size = tokenizer.vocab_size if tokenizer else 32000
        self.static_prefix_embedding = nn.Embedding(vocab_size, hidden_size)
        self.embeddings = self.static_prefix_embedding  # keep this alias for compatibility
        self.sensor_mlp = nn.Sequential(
            nn.Linear(3, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size)
        )
        # LayerNorm applied to the mixed prefix
        self.norm = nn.LayerNorm(hidden_size)
        print(f"[INFO] model initialized: hidden_size={hidden_size}, vocab_size={vocab_size}")

    def forward(self, input_ids, attention_mask, static_prefix=None, sensor_data=None, labels=None, **kwargs):
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        # Make sure all components live on the same device
        self.static_prefix_embedding = self.static_prefix_embedding.to(device)
        self.sensor_mlp = self.sensor_mlp.to(device)
        self.norm = self.norm.to(device)
        # Static prefix
        if static_prefix is not None:
            static_prefix = static_prefix.to(device)
            static_prefix = self.static_prefix_embedding(static_prefix)
        else:
            static_prefix = torch.zeros(
                (batch_size, self.num_prefix, self.model.config.hidden_size),
                device=device
            )
        # Dynamic prefix from sensor data
        if sensor_data is not None:
            sensor_data = sensor_data.to(device)
            if sensor_data.dim() == 1:
                sensor_data = sensor_data.unsqueeze(0)
            try:
                dynamic_prefix = self.sensor_mlp(sensor_data)
                dynamic_prefix = dynamic_prefix.unsqueeze(1).expand(-1, self.num_prefix, -1)
            except Exception as e:
                print(f"[ERROR] sensor_mlp failed: {e}")
                dynamic_prefix = torch.zeros_like(static_prefix)
        else:
            dynamic_prefix = torch.zeros_like(static_prefix)
        # Mix the two prefixes
        alpha = 0.6
        final_prefix = alpha * static_prefix + (1 - alpha) * dynamic_prefix
        final_prefix = self.norm(final_prefix)
        # Token embeddings
        token_embeds = self.model.model.embed_tokens(input_ids)
        input_embeds = torch.cat((final_prefix, token_embeds), dim=1)
        # Extend the attention mask to cover the prefix
        prefix_attention_mask = torch.ones(
            (batch_size, self.num_prefix),
            dtype=attention_mask.dtype,
            device=device
        )
        extended_attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        # Key fix: extend the labels as well
        if labels is not None:
            # Use -100 for the prefix positions so they are ignored by the loss
            prefix_labels = torch.full(
                (batch_size, self.num_prefix),
                fill_value=-100,
                dtype=labels.dtype,
                device=device
            )
            extended_labels = torch.cat((prefix_labels, labels), dim=1)
        else:
            extended_labels = None
        # Debug output
        # print(f"[DEBUG] original input size: {input_ids.shape}")
        # print(f"[DEBUG] extended embeddings size: {input_embeds.shape}")
        # print(f"[DEBUG] extended mask size: {extended_attention_mask.shape}")
        # if extended_labels is not None:
        #     print(f"[DEBUG] extended labels size: {extended_labels.shape}")
        # Make sure input_ids is not also passed through kwargs
        if 'input_ids' in kwargs:
            del kwargs['input_ids']
        # Pass the extended labels along with the prefixed embeddings
        return self.model(
            inputs_embeds=input_embeds,
            attention_mask=extended_attention_mask,
            labels=extended_labels,
            use_cache=False,
            **kwargs)
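

# Usage sketch for KoPAWithAdapter (commented out; illustrative only). The
# checkpoint path, prompt text, and tensor values are assumptions chosen for
# demonstration; sensor_data must have 3 features to match sensor_mlp.
#
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("path/to/llama")        # hypothetical path
# base = LlamaForCausalLM.from_pretrained("path/to/llama")
# wrapper = KoPAWithAdapter(base, num_prefix=4, tokenizer=tokenizer)
# enc = tokenizer("example prompt", return_tensors="pt")
# out = wrapper(
#     input_ids=enc["input_ids"],
#     attention_mask=enc["attention_mask"],
#     static_prefix=torch.randint(0, tokenizer.vocab_size, (1, 4)),  # one token id per prefix slot
#     sensor_data=torch.tensor([[0.7, 0.2, 0.1]]),                   # 3 sensor readings per sample
#     labels=enc["input_ids"].clone(),
# )
# print(out.loss)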


# class PrefixKGEmbedding(nn.Module):
#     def __init__(
#         self,
#         num_ent,
#         num_rel,
#         dim_llm,
#         num_prefix
#     ):
#         super(PrefixKGEmbedding, self).__init__()
#         self.emb_dim = num_prefix * dim_llm
#         self.ent_embeddings = nn.Embedding(num_ent, self.emb_dim)
#         self.rel_embeddings = nn.Embedding(num_rel, self.emb_dim)
#
#
#     def forward(self, triple_ids):
#         head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
#         h = self.ent_embeddings(head)
#         r = self.rel_embeddings(relation)
#         t = self.ent_embeddings(tail)
#         prefix = torch.stack((h, r, t), dim=1)
#         return prefix


class PretrainKGEmbedding(nn.Module):
    def __init__(
        self,
        pretrain_ent_embs,
        pretrain_rel_embs,
        dim_llm,
        num_prefix
    ):
        super(PretrainKGEmbedding, self).__init__()
        self.num_prefix = num_prefix
        self.llm_dim = dim_llm
        self.emb_dim = num_prefix * dim_llm
        self.ent_embeddings = nn.Embedding.from_pretrained(pretrain_ent_embs)
        self.rel_embeddings = nn.Embedding.from_pretrained(pretrain_rel_embs)
        self.pretrain_dim = self.ent_embeddings.weight.shape[1]
        # Freeze the pretrained embeddings
        self.ent_embeddings.requires_grad_(False)
        self.rel_embeddings.requires_grad_(False)
        self.adapter = nn.Linear(self.pretrain_dim, self.emb_dim)

    def forward(self, triple_ids):
        # Main training stage: (batch, 3) head/relation/tail triples
        if triple_ids.shape[1] == 3:
            head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
            h = self.ent_embeddings(head)
            r = self.rel_embeddings(relation)
            t = self.ent_embeddings(tail)
            pretrain_embs = torch.stack((h, r, t), dim=1)
            prefix = self.adapter(pretrain_embs).reshape(-1, 3*self.num_prefix, self.llm_dim)
            return prefix
        # Entity-aware pre-tuning: a batch of entity ids only
        else:
            ent = triple_ids.reshape(-1,)
            emb = self.ent_embeddings(ent)
            prefix = self.adapter(emb).reshape(-1, self.num_prefix, self.llm_dim)
            # print(prefix.shape)
            return prefix
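

# Minimal runnable sketch for PretrainKGEmbedding (illustrative only). The
# random tensors stand in for real pretrained KG embeddings; 2034 entities and
# 42 relations follow the commented-out PrefixKGEmbedding config above, while
# the 512-dim pretrained size is an assumption.
if __name__ == "__main__":
    fake_ent = torch.randn(2034, 512)
    fake_rel = torch.randn(42, 512)
    kge = PretrainKGEmbedding(
        pretrain_ent_embs=fake_ent,
        pretrain_rel_embs=fake_rel,
        dim_llm=3072,
        num_prefix=1,
    )
    triples = torch.tensor([[0, 1, 2], [3, 4, 5]])  # (batch, 3) head/relation/tail ids
    print(kge(triples).shape)                       # expected: torch.Size([2, 3, 3072])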