import torch
import torch.nn as nn
from typing import Optional, List, Union, Tuple
from transformers import LlamaForCausalLM


class KoPA(nn.Module):
    def __init__(
        self,
        model,
        num_prefix
    ) -> None:
        super(KoPA, self).__init__()
        self.llama_model = model
        self.num_prefix = num_prefix
        hidden_size = model.config.hidden_size
        # Learnable KG prefix embeddings (100 slots, one hidden-size vector each)
        self.embeddings = nn.Embedding(100, hidden_size)
        # Freeze the base model; only keep gradients for the adapter parts
        for param in self.llama_model.parameters():
            param.requires_grad = False
        self.embeddings.requires_grad_(True)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        embedding_ids: torch.LongTensor = None
    ):
        # Guard against out-of-range KG embedding ids
        if embedding_ids.max() >= self.embeddings.num_embeddings or embedding_ids.min() < 0:
            print(f"[ERROR] embedding_ids out of range! max: {embedding_ids.max()}, min: {embedding_ids.min()}")
            embedding_ids = torch.clamp(embedding_ids, min=0, max=self.embeddings.num_embeddings - 1)
        kg_embeds = self.embeddings(embedding_ids)
        batch_size, seq_len, _ = kg_embeds.shape

        if hasattr(self.llama_model, 'transformer'):  # Qwen-style models
            token_embeds = self.llama_model.transformer.wte(input_ids)
        elif hasattr(self.llama_model, 'model') and hasattr(self.llama_model.model, 'embed_tokens'):  # original LLaMA path
            token_embeds = self.llama_model.model.embed_tokens(input_ids)
        else:
            # Debugging aid
            print("Could not find the model's embedding layer; inspect the model structure...")
            raise ValueError("Incompatible model structure")

        # Prepend the KG prefix to the token embeddings and extend the mask/labels accordingly
        input_embeds = torch.cat((kg_embeds, token_embeds), dim=1)
        prefix_mask = torch.ones((batch_size, seq_len), dtype=attention_mask.dtype, device=attention_mask.device)
        prefix_labels = torch.full((batch_size, seq_len), fill_value=-100, dtype=torch.long, device=labels.device)
        new_attention_mask = torch.cat((prefix_mask, attention_mask), dim=-1)
        new_labels = torch.cat((prefix_labels, labels), dim=-1)

        return self.llama_model(
            input_ids=None,
            attention_mask=new_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=input_embeds,
            labels=new_labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
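

# Hedged usage sketch, not part of the original training pipeline: it wires KoPA around a
# tiny randomly initialised LlamaForCausalLM so the prefix plumbing can be smoke-tested
# offline. The config sizes, batch shapes, and the function name are illustrative
# assumptions, not values taken from this repo.
def _kopa_smoke_test():
    from transformers import LlamaConfig

    cfg = LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                      num_attention_heads=4, vocab_size=1000)
    base = LlamaForCausalLM(cfg)                      # assumed stand-in for the real checkpoint
    wrapper = KoPA(base, num_prefix=1)

    input_ids = torch.randint(0, 1000, (2, 8))
    attention_mask = torch.ones_like(input_ids)
    labels = input_ids.clone()
    embedding_ids = torch.randint(0, 100, (2, 1))     # indices into the 100-slot KG embedding table

    out = wrapper(input_ids=input_ids, attention_mask=attention_mask,
                  labels=labels, embedding_ids=embedding_ids)
    print(out.loss)                                   # scalar language-modelling loss
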
class KoPAWithAdapter(nn.Module):
    def __init__(self, model, num_prefix, tokenizer=None):
        super().__init__()
        self.model = model
        self.num_prefix = num_prefix
        hidden_size = model.config.hidden_size

        # Print model info for debugging
        print(f"[INFO] Initializing KoPAWithAdapter, model type: {type(model).__name__}")

        # Get vocab_size from the tokenizer if available
        vocab_size = tokenizer.vocab_size if tokenizer else 151936  # default vocab size for Qwen2.5
        print(f"[INFO] Using vocab size: {vocab_size}")

        self.static_prefix_embedding = nn.Embedding(vocab_size, hidden_size)
        self.embeddings = self.static_prefix_embedding  # keep this alias for compatibility

        self.sensor_mlp = nn.Sequential(
            nn.Linear(3, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size)
        )

        # LayerNorm over the mixed prefix
        self.norm = nn.LayerNorm(hidden_size)

        print(f"[INFO] Model initialized: hidden_size={hidden_size}, vocab_size={vocab_size}")

        # Detect the model's embedding-layer path
        self._detect_embedding_path()

    def _detect_embedding_path(self):
        """Detect the path to the model's token embedding layer."""
        self.embedding_path = None

        # Try the common paths
        if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
            self.embedding_path = "transformer.wte"
        elif hasattr(self.model, 'model') and hasattr(self.model.model, 'embed_tokens'):
            self.embedding_path = "model.embed_tokens"
        elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(self.model.model.model, 'embed_tokens'):
            self.embedding_path = "model.model.model.embed_tokens"

        if self.embedding_path:
            print(f"[INFO] Detected embedding path: {self.embedding_path}")
        else:
            print("[WARNING] Could not auto-detect the embedding path; several paths will be tried in forward()")

    def forward(self, input_ids, attention_mask, static_prefix=None, sensor_data=None, labels=None, **kwargs):
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Make sure all adapter components are on the same device as the inputs
        self.static_prefix_embedding = self.static_prefix_embedding.to(device)
        self.sensor_mlp = self.sensor_mlp.to(device)
        self.norm = self.norm.to(device)

        # Static prefix: look up embeddings for the given prefix token ids
        if static_prefix is not None:
            static_prefix = static_prefix.to(device)
            static_prefix = self.static_prefix_embedding(static_prefix)
        else:
            static_prefix = torch.zeros(
                (batch_size, self.num_prefix, self.model.config.hidden_size),
                device=device
            )

        # Dynamic prefix: project the sensor reading into the hidden space
        if sensor_data is not None:
            sensor_data = sensor_data.to(device)
            if sensor_data.dim() == 1:
                sensor_data = sensor_data.unsqueeze(0)
            try:
                dynamic_prefix = self.sensor_mlp(sensor_data)
                dynamic_prefix = dynamic_prefix.unsqueeze(1).expand(-1, self.num_prefix, -1)
            except Exception as e:
                print(f"[ERROR] sensor_mlp failed: {e}")
                dynamic_prefix = torch.zeros_like(static_prefix)
        else:
            dynamic_prefix = torch.zeros_like(static_prefix)

        # Mix the static and dynamic prefixes
        alpha = 0.6
        final_prefix = alpha * static_prefix + (1 - alpha) * dynamic_prefix
        final_prefix = self.norm(final_prefix)

        # Token embeddings, fetched via the detected embedding path
        try:
            if self.embedding_path == "transformer.wte":
                token_embeds = self.model.transformer.wte(input_ids)
            elif self.embedding_path == "model.embed_tokens":
                token_embeds = self.model.model.embed_tokens(input_ids)
            elif self.embedding_path == "model.model.model.embed_tokens":
                token_embeds = self.model.model.model.embed_tokens(input_ids)
            else:
                # Try several possible paths
                if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
                    token_embeds = self.model.transformer.wte(input_ids)
                    self.embedding_path = "transformer.wte"
                elif hasattr(self.model, 'model') and hasattr(self.model.model, 'embed_tokens'):
                    token_embeds = self.model.model.embed_tokens(input_ids)
                    self.embedding_path = "model.embed_tokens"
                elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(
                        self.model.model.model, 'embed_tokens'):
                    token_embeds = self.model.model.model.embed_tokens(input_ids)
                    self.embedding_path = "model.model.model.embed_tokens"
                else:
                    raise ValueError("Could not find an embedding-layer path")
                print(f"[INFO] Found embedding path: {self.embedding_path}")
        except Exception as e:
            print(f"[ERROR] Failed to get token embeddings: {e}")
            # Print the embedding-related module names to help debugging
            print("Model structure:")
            for name, _ in self.model.named_modules():
                if 'embed' in name or 'wte' in name:
                    print(f"  - {name}")
            raise

        input_embeds = torch.cat((final_prefix, token_embeds), dim=1)

        # Extend the attention mask to cover the prefix
        prefix_attention_mask = torch.ones(
            (batch_size, self.num_prefix),
            dtype=attention_mask.dtype,
            device=device
        )
        extended_attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        # Labels: prefix positions get -100 so they are ignored by the loss
        if labels is not None:
            prefix_labels = torch.full(
                (batch_size, self.num_prefix),
                fill_value=-100,  # -100 means these positions are ignored in the loss
                dtype=labels.dtype,
                device=device
            )
            # Extend the labels to cover the prefix
            extended_labels = torch.cat((prefix_labels, labels), dim=1)
        else:
            extended_labels = None

        # Make sure input_ids is not passed alongside inputs_embeds
        if 'input_ids' in kwargs:
            del kwargs['input_ids']

        # Match the base model's dtype
        model_dtype = next(self.model.parameters()).dtype
        input_embeds = input_embeds.to(dtype=model_dtype)
        extended_attention_mask = extended_attention_mask.to(dtype=model_dtype)

        # Forward with the extended embeddings, mask, and labels
        return self.model(
            inputs_embeds=input_embeds,
            attention_mask=extended_attention_mask,
            labels=extended_labels,
            use_cache=False,
            **kwargs)
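

# Hedged usage sketch, not part of the original training pipeline: it drives KoPAWithAdapter
# with a tiny randomly initialised LlamaForCausalLM, dummy prefix token ids, and a fake
# 3-channel sensor reading. All sizes and the function name are illustrative assumptions.
def _kopa_with_adapter_smoke_test():
    from transformers import LlamaConfig

    cfg = LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                      num_attention_heads=4, vocab_size=1000)
    base = LlamaForCausalLM(cfg)
    adapter_model = KoPAWithAdapter(base, num_prefix=4)   # tokenizer=None -> falls back to the Qwen2.5 vocab size

    batch_size, seq_len = 2, 8
    input_ids = torch.randint(0, 1000, (batch_size, seq_len))
    attention_mask = torch.ones_like(input_ids)
    labels = input_ids.clone()
    static_prefix = torch.randint(0, 1000, (batch_size, 4))   # one token id per prefix slot
    sensor_data = torch.randn(batch_size, 3)                  # matches sensor_mlp's 3-dim input

    out = adapter_model(input_ids=input_ids, attention_mask=attention_mask,
                        static_prefix=static_prefix, sensor_data=sensor_data, labels=labels)
    print(out.loss)
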
# class PrefixKGEmbedding(nn.Module):
#     def __init__(
#         self,
#         num_ent,
#         num_rel,
#         dim_llm,
#         num_prefix
#     ):
#         super(PrefixKGEmbedding, self).__init__()
#         self.emb_dim = num_prefix * dim_llm
#         self.ent_embeddings = nn.Embedding(num_ent, self.emb_dim)
#         self.rel_embeddings = nn.Embedding(num_rel, self.emb_dim)
#
#     def forward(self, triple_ids):
#         head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
#         h = self.ent_embeddings(head)
#         r = self.rel_embeddings(relation)
#         t = self.ent_embeddings(tail)
#         prefix = torch.stack((h, r, t), dim=1)
#         return prefix


class PretrainKGEmbedding(nn.Module):
    def __init__(
        self,
        pretrain_ent_embs,
        pretrain_rel_embs,
        dim_llm,
        num_prefix
    ):
        super(PretrainKGEmbedding, self).__init__()
        self.num_prefix = num_prefix
        self.llm_dim = dim_llm
        self.emb_dim = num_prefix * dim_llm
        self.ent_embeddings = nn.Embedding.from_pretrained(pretrain_ent_embs)
        self.rel_embeddings = nn.Embedding.from_pretrained(pretrain_rel_embs)
        self.pretrain_dim = self.ent_embeddings.weight.shape[1]
        # Freeze the pretrained embeddings; only the adapter is trained
        self.ent_embeddings.requires_grad_(False)
        self.rel_embeddings.requires_grad_(False)
        self.adapter = nn.Linear(self.pretrain_dim, self.emb_dim)

    def forward(self, triple_ids):
        # Main training stage: (head, relation, tail) triples
        if triple_ids.shape[1] == 3:
            head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
            h = self.ent_embeddings(head)
            r = self.rel_embeddings(relation)
            t = self.ent_embeddings(tail)
            pretrain_embs = torch.stack((h, r, t), dim=1)
            prefix = self.adapter(pretrain_embs).reshape(-1, 3 * self.num_prefix, self.llm_dim)
            return prefix
        # Entity-aware pre-tuning: entity ids only
        else:
            ent = triple_ids.reshape(-1,)
            emb = self.ent_embeddings(ent)
            prefix = self.adapter(emb).reshape(-1, self.num_prefix, self.llm_dim)
            # print(prefix.shape)
            return prefix
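

# Hedged usage sketch, not part of the original pipeline: it feeds random stand-ins for
# pretrained entity/relation embeddings through PretrainKGEmbedding to show the expected
# prefix shapes for both the triple branch and the entity-only branch. All sizes and the
# function name are illustrative assumptions.
def _pretrain_kg_embedding_smoke_test():
    num_ent, num_rel, pretrain_dim = 50, 10, 32
    dim_llm, num_prefix = 64, 2

    ent_embs = torch.randn(num_ent, pretrain_dim)      # stand-in for pretrained KG entity embeddings
    rel_embs = torch.randn(num_rel, pretrain_dim)      # stand-in for pretrained KG relation embeddings
    kg_prefix = PretrainKGEmbedding(ent_embs, rel_embs, dim_llm=dim_llm, num_prefix=num_prefix)

    triples = torch.tensor([[0, 1, 2], [3, 4, 5]])     # (batch, 3) head / relation / tail ids
    print(kg_prefix(triples).shape)                    # torch.Size([2, 6, 64]) = (batch, 3 * num_prefix, dim_llm)

    entities = torch.tensor([[0], [1], [2], [3]])      # entity-only input for the pre-tuning stage
    print(kg_prefix(entities).shape)                   # torch.Size([4, 2, 64]) = (batch, num_prefix, dim_llm)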