diff --git a/finetune_kopa.py b/finetune_kopa.py
index 7066f3f..8e37193 100644
--- a/finetune_kopa.py
+++ b/finetune_kopa.py
@@ -19,6 +19,43 @@
 from peft import PrefixTuningConfig, get_peft_model
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils.prompter import Prompter
+import os
+
+os.environ["SAFETENSORS_FAST_SAVE"] = "0"
+os.environ["TOKENIZERS_PARALLELISM"] = "false"  # work around the tokenizer fork error
+
+
+def untie_shared_weights(model):
+    print("[INFO] Untying shared weights in the model...")
+
+    # For Qwen models, we need to handle specific weight sharing patterns
+    if hasattr(model, "model") and hasattr(model.model, "base_model") and hasattr(model.model.base_model, "model"):
+        base_model = model.model.base_model.model
+
+        # Handle the first shared weights: embed_tokens and word_embeddings
+        if hasattr(base_model, "embed_tokens") and hasattr(model.model, "word_embeddings"):
+            if id(base_model.embed_tokens.weight) == id(model.model.word_embeddings.weight):
+                print("[INFO] Untying shared weights between embed_tokens and word_embeddings")
+                # Create a new tensor with the same values
+                model.model.word_embeddings.weight = torch.nn.Parameter(
+                    base_model.embed_tokens.weight.clone()
+                )
+
+    # Handle the second shared weights: embeddings and static_prefix_embedding
+    if hasattr(model, "embeddings") and hasattr(model, "static_prefix_embedding"):
+        if id(model.embeddings.weight) == id(model.static_prefix_embedding.weight):
+            print("[INFO] Untying shared weights between embeddings and static_prefix_embedding")
+            # Create a new tensor with the same values
+            model.static_prefix_embedding.weight = torch.nn.Parameter(
+                model.embeddings.weight.clone()
+            )
+
+    # Disable any tie_weights methods
+    if hasattr(model, "tie_weights"):
+        model.tie_weights = lambda: None
+        print("[INFO] Disabled tie_weights method")
+
+    print("[INFO] Completed untying shared weights")


 def custom_collate_fn(batch):
@@ -55,20 +92,18 @@ def custom_collate_fn(batch):
         else:
             sensor_data = b["sensor_data"]
         sensor_data_list.append(sensor_data)

-    max_length=0
+    max_length = 0
     for one_inputs in input_ids_list:
         max_length = one_inputs.size(0) if max_length < one_inputs.size(0) else max_length
-    input_ids_list_=list()
+    input_ids_list_ = list()
     for one_inputs in input_ids_list:
-        input_ids_list_.append(torch.cat((one_inputs, torch.full((max_length-one_inputs.size(0),), 0, dtype=torch.int)), dim=-1))
+        input_ids_list_.append(
+            torch.cat((one_inputs, torch.full((max_length - one_inputs.size(0),), 151645, dtype=torch.int)), dim=-1))
-
-    attention_mask_list_=list()
+    attention_mask_list_ = list()
     for mask in attention_mask_list:
-        attention_mask_list_.append(torch.cat((mask, torch.full((max_length-mask.size(0),), 0, dtype=torch.int)), dim=-1))
-
-    # print("=====",input_ids_list)
-    # exit(0)
+        attention_mask_list_.append(
+            torch.cat((mask, torch.full((max_length - mask.size(0),), 0, dtype=torch.int)), dim=-1))

     # Stack the data
     result = {
@@ -90,10 +125,10 @@ def custom_collate_fn(batch):
         else:
             labels = b["labels"]
         labels_list.append(labels)
-    labels_list_=list()
+    labels_list_ = list()
     for label in labels_list:
-        labels_list_.append(torch.cat((label, torch.full((max_length-label.size(0),), 0, dtype=torch.int)), dim=-1))
-
+        labels_list_.append(
+            torch.cat((label, torch.full((max_length - label.size(0),), 151645, dtype=torch.int)), dim=-1))

     result["labels"] = torch.stack(labels_list_)

@@ -109,7 +144,7 @@ def train(
         batch_size: int = 16,
         micro_batch_size: int = 16,
         num_epochs: int = 2,
-        learning_rate: float = 1e-4,
+        learning_rate: float = 1e-4,
         cutoff_len: int = 512,
         val_set_size: int = 0,
         num_prefix: int = 1,
@@ -163,7 +198,7 @@ def train(
     model = AutoModelForCausalLM.from_pretrained(
         base_model,
-        load_in_8bit=True,
+        load_in_8bit=False,
         # let the Auto class pick the correct model type
         torch_dtype=torch.float16,
         device_map=device_map,
@@ -177,7 +212,8 @@
     )

     tokenizer.pad_token = tokenizer.eos_token
-
+    # print("=====",model.config.eos_token_id)
+    # exit(0)
     # tokenizer.pad_token_id = (
     #     0  # unk. we want this to be different from the eos token
     # )
@@ -203,28 +239,6 @@

         return dataset

-    # def tokenize(prompt, add_eos_token=True):
-    #     # there's probably a way to do this with the tokenizer settings
-    #     # but again, gotta move fast
-    #     result = tokenizer(
-    #         prompt,
-    #         truncation=True,
-    #         max_length=cutoff_len,
-    #         padding=False,
-    #         return_tensors=None,
-    #     )
-    #     if (
-    #             result["input_ids"][-1] != tokenizer.eos_token_id
-    #             and len(result["input_ids"]) < cutoff_len
-    #             and add_eos_token
-    #     ):
-    #         result["input_ids"].append(tokenizer.eos_token_id)
-    #         result["attention_mask"].append(1)
-    #
-    #     result["labels"] = result["input_ids"].copy()
-    #
-    #     return result
-
     def generate_and_tokenize_prompt(data_point):
         full_prompt = prompter.generate_prompt(
             data_point["instruction"],
@@ -244,7 +258,6 @@ def train(
         # exit(0)

-
         tokenized_full_prompt = {k: v.squeeze(0) for k, v in tokenized_full_prompt.items()}

         # Handle the static prefix
@@ -305,7 +318,7 @@ def train(
             elif isinstance(tokenized_full_prompt[key], torch.Tensor) and tokenized_full_prompt[key].dim() > 1:
                 # Squeeze extra dimensions if needed
                 tokenized_full_prompt[key] = tokenized_full_prompt[key].squeeze(0)
-
+
             if key in ["input_ids", "attention_mask"] and isinstance(tokenized_full_prompt[key], list):
                 tokenized_full_prompt[key] = torch.tensor(tokenized_full_prompt[key], dtype=torch.long)

@@ -319,7 +332,6 @@ def train(
         tokenized_full_prompt["labels"] = tokenized_full_prompt["input_ids"].clone()

-
        # To skip loss on the input portion, set those labels to -100
         if not train_on_inputs:
             # Find the boundary between the user input and the assistant output
             sep = tokenizer.encode(prompter.separator)
@@ -331,7 +343,6 @@ def train(
         return tokenized_full_prompt

-
     # Create the PrefixTuning config
     prefix_config = PrefixTuningConfig(
@@ -342,7 +353,6 @@ def train(
     # Create the PEFT model
     peft_model = get_peft_model(model, prefix_config)

-
     # Create the final KoPAWithAdapter model
     final_model = KoPAWithAdapter(peft_model, num_prefix, tokenizer)
     device = next(model.parameters()).device
@@ -351,7 +361,6 @@ def train(
     # Make sure final_model and all of its components are on the same device
     final_model = final_model.to(device)

-
     if data_path.endswith(".json") or data_path.endswith(".jsonl"):
         data = load_dataset("json", data_files=data_path)
     else:
@@ -400,6 +409,10 @@ def train(
         model.is_parallelizable = True
         model.model_parallel = True

+    untie_shared_weights(final_model)
+
+    # For KoPAWithAdapter models, we need a custom save approach
+
     trainer = transformers.Trainer(
         model=final_model,
         data_collator=custom_collate_fn,
@@ -411,13 +424,13 @@ def train(
             warmup_steps=100,
             num_train_epochs=num_epochs,
             learning_rate=learning_rate,
-            fp16=True,
+            fp16=False,
             logging_steps=10,
             optim="adamw_hf",
             evaluation_strategy="steps" if val_set_size > 0 else "no",
             save_strategy="steps",
             eval_steps=None,
-            save_steps=5000,
+            save_steps=10,
             output_dir=output_dir,
             save_total_limit=2,
             load_best_model_at_end=True if val_set_size > 0 else False,
@@ -432,27 +445,69 @@ def train(
     if torch.__version__ >= "2" and sys.platform != "win32":
         final_model = torch.compile(model)

+    untie_shared_weights(final_model)
+
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)

-    final_model.save_pretrained(output_dir)
-
-    # ⭐ Make sure embeddings exist before saving
-    if hasattr(final_model, "embeddings"):
-        torch.save(final_model.embeddings, os.path.join(output_dir, "embeddings.pth"))
-    else:
-        print("[WARNING] final_model has no embeddings attribute, skipping save.")
-
     try:
-        final_model.model.save_pretrained(os.path.join(output_dir, "peft_model"))
-        print(f"[INFO] PEFT model saved to {os.path.join(output_dir, 'peft_model')}")
+        final_model = untie_shared_weights(final_model)
+        print(f"[INFO] Saving model to {output_dir}")
+
+        # Make sure the output directory exists
+        os.makedirs(output_dir, exist_ok=True)
+
+        # In distributed training, save only on the main process
+        if int(os.environ.get("LOCAL_RANK", 0)) == 0:
+            # Move the model to CPU for saving
+            model_to_save = final_model.module if hasattr(final_model, "module") else final_model
+            model_to_save = model_to_save.cpu()
+
+            try:
+                # Save the main model components
+                if hasattr(final_model, "save_model"):
+                    final_model.save_model(output_dir)
+                else:
+                    # Save model configuration
+                    if hasattr(final_model, "config"):
+                        final_model.config.save_pretrained(output_dir)
+
+                    # Save model state dict
+                    torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
+                    print(f"[INFO] Successfully saved model state dict")
+
+                # Save embeddings separately if they exist
+                if hasattr(final_model, "embeddings"):
+                    torch.save(final_model.embeddings, os.path.join(output_dir, "embeddings.pth"))
+                    print(f"[INFO] Successfully saved embeddings")
+
+                # Save PEFT model components
+                if hasattr(final_model, "model") and hasattr(final_model.model, "save_pretrained"):
+                    peft_save_dir = os.path.join(output_dir, "peft_model")
+                    os.makedirs(peft_save_dir, exist_ok=True)
+                    final_model.model.save_pretrained(peft_save_dir)
+                    print(f"[INFO] PEFT model saved to {peft_save_dir}")
+
+                # Move the model back to its original device once saving is done
+                model_to_save = model_to_save.to(device)
+
+            except Exception as e:
+                print(f"[ERROR] Error during model saving: {str(e)}")
+                import traceback
+                traceback.print_exc()
+                raise e
+
     except Exception as e:
-        print(f"[WARNING] Error while saving the PEFT model: {e}")
+        print(f"[ERROR] Error in save process: {str(e)}")
+        import traceback
+        traceback.print_exc()
+        raise e
+

 def inspect_model_structure(model):
     """Inspect the model structure and print key layer info"""
     print(f"Model type: {type(model).__name__}")
     print(f"Model config: {model.config.__class__.__name__}")
-
+
     # Inspect embedding layers
     embedding_layers = []
     for name, module in model.named_modules():
@@ -460,11 +515,11 @@
             embedding_layers.append((name, type(module).__name__))
             if hasattr(module, 'weight'):
                 print(f"Layer {name}: shape {module.weight.shape}")
-
+
     print(f"Found {len(embedding_layers)} potential embedding layers:")
     for name, type_name in embedding_layers:
         print(f"  - {name}: {type_name}")
-
+
     # Inspect attention layers
     print("\nAttention structure:")
     for name, module in model.named_modules():
diff --git a/kopa.py b/kopa.py
index ea362ed..13b41ab 100644
--- a/kopa.py
+++ b/kopa.py
@@ -7,38 +7,39 @@ from transformers import LlamaForCausalLM

 class KoPA(nn.Module):
     def __init__(
-        self,
-        model
+            self,
+            model
     ) -> None:
         super(KoPA, self).__init__()
         self.llama_model = model
         for param in self.model.parameters():
             param.requires_grad = False
-
+
         # Only keep gradients for the adapter parts
-        self.num_prefix = num_prefix
+        # self.num_prefix = num_prefix
         hidden_size = model.config.hidden_size
         self.embeddings = nn.Embedding(100, 4096)
         for param in model.parameters():
             param.requires_grad = False
-
+
         # Only enable gradients for adapter components
         self.static_prefix_embedding.requires_grad_(True)
         self.sensor_mlp.requires_grad_(True)
         self.norm.requires_grad_(True)
+
     def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        embedding_ids: torch.LongTensor = None
+            self,
+            input_ids: torch.LongTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            embedding_ids: torch.LongTensor = None
     ):
         if embedding_ids.max() >= self.embeddings.num_embeddings or embedding_ids.min() < 0:
             print(f"[ERROR] embedding_ids out of range! max: {embedding_ids.max()}, min: {embedding_ids.min()}")
@@ -84,10 +85,10 @@ class KoPAWithAdapter(nn.Module):
         self.model = model
         self.num_prefix = num_prefix
         hidden_size = model.config.hidden_size
-
+
         # Print model info for debugging
         print(f"[INFO] Initializing KoPAWithAdapter, model type: {type(model).__name__}")
-
+
         # Use the tokenizer to get vocab_size
         vocab_size = tokenizer.vocab_size if tokenizer else 151936  # default vocab size for Qwen2.5
         print(f"[INFO] Using vocab size: {vocab_size}")
@@ -106,22 +107,23 @@ class KoPAWithAdapter(nn.Module):
         self.norm = nn.LayerNorm(hidden_size)

         print(f"[INFO] Model initialized: hidden_size={hidden_size}, vocab_size={vocab_size}")
-
+
         # Detect the model's embedding-layer path
         self._detect_embedding_path()
-
+
     def _detect_embedding_path(self):
         """Detect the path to the model's embedding layer"""
         self.embedding_path = None
-
+
         # Try several common paths
         if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
             self.embedding_path = "transformer.wte"
         elif hasattr(self.model, 'model') and hasattr(self.model.model, 'embed_tokens'):
             self.embedding_path = "model.embed_tokens"
-        elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(self.model.model.model, 'embed_tokens'):
+        elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(self.model.model.model,
+                                                                                              'embed_tokens'):
             self.embedding_path = "model.model.model.embed_tokens"
-
+
         if self.embedding_path:
             print(f"[INFO] Detected embedding-layer path: {self.embedding_path}")
         else:
@@ -183,7 +185,8 @@ class KoPAWithAdapter(nn.Module):
             elif hasattr(self.model, 'model') and hasattr(self.model.model, 'embed_tokens'):
                 token_embeds = self.model.model.embed_tokens(input_ids)
                 self.embedding_path = "model.embed_tokens"
-            elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(self.model.model.model, 'embed_tokens'):
+            elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(
+                    self.model.model.model, 'embed_tokens'):
                 token_embeds = self.model.model.model.embed_tokens(input_ids)
                 self.embedding_path = "model.model.model.embed_tokens"
             else:
@@ -226,6 +229,18 @@ class KoPAWithAdapter(nn.Module):
         if 'input_ids' in kwargs:
             del kwargs['input_ids']

+        model_dtype = next(self.model.parameters()).dtype
+        input_embeds = input_embeds.to(dtype=model_dtype)
+
+        # Remaining code as before...
+        prefix_attention_mask = torch.ones(
+            (batch_size, self.num_prefix),
+            dtype=attention_mask.dtype,
+            device=device
+        )
+        extended_attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+        extended_attention_mask = extended_attention_mask.to(dtype=model_dtype)
+
         # Pass the extended labels along
         return self.model(
             inputs_embeds=input_embeds,
@@ -234,6 +249,7 @@
             use_cache=False,
             **kwargs)

+
 # class PrefixKGEmbedding(nn.Module):
 #     def __init__(
 #             self,
@@ -258,11 +274,11 @@

 class PretrainKGEmbedding(nn.Module):
     def __init__(
-        self,
-        pretrain_ent_embs,
-        pretrain_rel_embs,
-        dim_llm,
-        num_prefix
+            self,
+            pretrain_ent_embs,
+            pretrain_rel_embs,
+            dim_llm,
+            num_prefix
     ):
         super(PretrainKGEmbedding, self).__init__()
         self.num_prefix = num_prefix
@@ -276,7 +292,6 @@ class PretrainKGEmbedding(nn.Module):
         self.rel_embeddings.requires_grad_(False)
         self.adapter = nn.Linear(self.pretrain_dim, self.emb_dim)

-
     def forward(self, triple_ids):
         # main training stage
         if triple_ids.shape[1] == 3:
@@ -285,11 +300,11 @@ class PretrainKGEmbedding(nn.Module):
             r = self.rel_embeddings(relation)
             t = self.ent_embeddings(tail)
             pretrain_embs = torch.stack((h, r, t), dim=1)
-            prefix = self.adapter(pretrain_embs).reshape(-1, 3*self.num_prefix, self.llm_dim)
+            prefix = self.adapter(pretrain_embs).reshape(-1, 3 * self.num_prefix, self.llm_dim)
             return prefix
         # entity-aware pre-funing
         else:
-            ent = triple_ids.reshape(-1,)
+            ent = triple_ids.reshape(-1, )
             emb = self.ent_embeddings(ent)
             prefix = self.adapter(emb).reshape(-1, self.num_prefix, self.llm_dim)
             # print(prefix.shape)