# kopa/finetune_kopa.py

import os
import sys
from typing import List
import fire
import torch
import transformers
from datasets import load_dataset
from kopa import KoPAWithAdapter
"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""
from peft import PrefixTuningConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.prompter import Prompter
os.environ["SAFETENSORS_FAST_SAVE"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # avoid tokenizer fork warnings/errors


def untie_shared_weights(model):
    print("[INFO] Untying shared weights in the model...")
    # For Qwen models, we need to handle specific weight sharing patterns
    if hasattr(model, "model") and hasattr(model.model, "base_model") and hasattr(model.model.base_model, "model"):
        base_model = model.model.base_model.model
        # Handle the first shared weights: embed_tokens and word_embeddings
        if hasattr(base_model, "embed_tokens") and hasattr(model.model, "word_embeddings"):
            if id(base_model.embed_tokens.weight) == id(model.model.word_embeddings.weight):
                print("[INFO] Untying shared weights between embed_tokens and word_embeddings")
                # Create a new parameter with the same values
                model.model.word_embeddings.weight = torch.nn.Parameter(
                    base_model.embed_tokens.weight.clone()
                )
    # Handle the second shared weights: embeddings and static_prefix_embedding
    if hasattr(model, "embeddings") and hasattr(model, "static_prefix_embedding"):
        if id(model.embeddings.weight) == id(model.static_prefix_embedding.weight):
            print("[INFO] Untying shared weights between embeddings and static_prefix_embedding")
            # Create a new parameter with the same values
            model.static_prefix_embedding.weight = torch.nn.Parameter(
                model.embeddings.weight.clone()
            )
    # Disable any tie_weights methods
    if hasattr(model, "tie_weights"):
        model.tie_weights = lambda: None
        print("[INFO] Disabled tie_weights method")
    print("[INFO] Completed untying shared weights")
    # Return the model so call sites can reassign it (it is also modified in place).
    return model


def custom_collate_fn(batch):
    input_ids_list = []
    attention_mask_list = []
    static_prefix_list = []
    sensor_data_list = []
    # qwen_dict= {'llama_eos_tid':, 'qwen_eos_tid':}
    for b in batch:
        # Ensure the inputs are tensors
        if isinstance(b["input_ids"], list):
            input_ids = torch.tensor(b["input_ids"], dtype=torch.long)
        else:
            input_ids = b["input_ids"]
        input_ids_list.append(input_ids)
        if isinstance(b["attention_mask"], list):
            attention_mask = torch.tensor(b["attention_mask"], dtype=torch.long)
        else:
            attention_mask = b["attention_mask"]
        attention_mask_list.append(attention_mask)
        if "static_prefix" in b:
            if isinstance(b["static_prefix"], list):
                static_prefix = torch.tensor(b["static_prefix"], dtype=torch.long)
            else:
                static_prefix = b["static_prefix"]
            static_prefix_list.append(static_prefix)
        if "sensor_data" in b:
            if isinstance(b["sensor_data"], list):
                sensor_data = torch.tensor(b["sensor_data"], dtype=torch.float)
            else:
                sensor_data = b["sensor_data"]
            sensor_data_list.append(sensor_data)
    # Right-pad every sequence in the batch to the longest length.
    max_length = 0
    for one_inputs in input_ids_list:
        max_length = max(max_length, one_inputs.size(0))
    # 151645 is Qwen2.5's <|im_end|> token id, used here as the padding id.
    input_ids_list_ = []
    for one_inputs in input_ids_list:
        input_ids_list_.append(
            torch.cat((one_inputs, torch.full((max_length - one_inputs.size(0),), 151645, dtype=torch.long)), dim=-1))
    attention_mask_list_ = []
    for mask in attention_mask_list:
        attention_mask_list_.append(
            torch.cat((mask, torch.full((max_length - mask.size(0),), 0, dtype=torch.long)), dim=-1))
    # Stack everything into batch tensors
    result = {
        "input_ids": torch.stack(input_ids_list_),
        "attention_mask": torch.stack(attention_mask_list_),
    }
    if static_prefix_list:
        result["static_prefix"] = torch.stack(static_prefix_list)
    if sensor_data_list:
        result["sensor_data"] = torch.stack(sensor_data_list)
    if "labels" in batch[0]:
        labels_list = []
        for b in batch:
            if isinstance(b["labels"], list):
                labels = torch.tensor(b["labels"], dtype=torch.long)
            else:
                labels = b["labels"]
            labels_list.append(labels)
        labels_list_ = []
        for label in labels_list:
            labels_list_.append(
                torch.cat((label, torch.full((max_length - label.size(0),), 151645, dtype=torch.long)), dim=-1))
        result["labels"] = torch.stack(labels_list_)
    return result
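
# Minimal usage sketch (field names match what generate_and_tokenize_prompt emits
# below; the lengths and values are made up for illustration):
#
#     batch = [
#         {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1],
#          "static_prefix": [5] * 10, "sensor_data": [21.0, 0.4, 1.2], "labels": [1, 2, 3]},
#         {"input_ids": [4, 5], "attention_mask": [1, 1],
#          "static_prefix": [7] * 10, "sensor_data": [19.5, 0.3, 0.9], "labels": [4, 5]},
#     ]
#     out = custom_collate_fn(batch)
#     # out["input_ids"].shape      -> torch.Size([2, 3])  (padded with 151645)
#     # out["attention_mask"].shape -> torch.Size([2, 3])  (padded with 0)
#     # out["static_prefix"].shape  -> torch.Size([2, 10])
#     # out["sensor_data"].shape    -> torch.Size([2, 3])
#     # out["labels"].shape         -> torch.Size([2, 3])  (also padded with 151645)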


def train(
    # model/data params
    base_model: str = "/root/shared-nvme/models/Qwen2.5-7B-Instruct",
    data_path: str = "/root/shared-nvme/dataset/olive_dataset.json",
    output_dir: str = "output",
    # training hyperparams
    batch_size: int = 8,
    micro_batch_size: int = 4,
    num_epochs: int = 2,
    learning_rate: float = 1e-5,
    cutoff_len: int = 512,
    val_set_size: int = 0,
    num_prefix: int = 10,
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either a training checkpoint or a final adapter
    prompt_template_name: str = "alpaca",  # the prompt template to use; defaults to alpaca
):
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    gradient_accumulation_steps = batch_size // micro_batch_size
    prompter = Prompter(prompt_template_name)
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size
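    # With the defaults above, gradient_accumulation_steps = batch_size // micro_batch_size
    # = 8 // 4 = 2, so each optimizer step sees micro_batch_size * gradient_accumulation_steps
    # * world_size = 8 samples; the DDP branch divides the accumulation steps by world_size
    # so the global batch stays at batch_size for world sizes that divide it evenly.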
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=False,
        # Let the Auto class pick the correct model type
        torch_dtype=torch.float16,
        device_map=device_map,
        trust_remote_code=True,  # required for Qwen models
    )
    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        trust_remote_code=True,  # required for Qwen tokenizers
        padding_side="left",  # Qwen also recommends left padding
    )
    tokenizer.pad_token = tokenizer.eos_token
    # print("=====", model.config.eos_token_id)
    # exit(0)
    # tokenizer.pad_token_id = (
    #     0  # unk. we want this to be different from the eos token
    # )
    # tokenizer.padding_side = "left"  # Allow batched inference
    # model.gradient_checkpointing_enable()
    # tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    def ensure_consistent_keys(dataset):
        # Collect the union of keys across all examples...
        all_keys = set()
        for example in dataset:
            all_keys.update(example.keys())

        # ...then fill any missing fields via map(); mutating rows while iterating
        # a HF Dataset does not persist, so the fill has to go through map().
        def fill_missing(example):
            for key in all_keys:
                if key not in example:
                    if key == "static_prefix":
                        example[key] = ""
                    elif key == "sensor_data":
                        example[key] = [0, 0, 0]
            return example

        return dataset.map(fill_missing)

    def generate_and_tokenize_prompt(data_point):
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
        # Tokenize the full prompt
        tokenized_full_prompt = tokenizer(
            full_prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=True,
            return_tensors="pt",
        )
        tokenized_full_prompt = {k: v.squeeze(0) for k, v in tokenized_full_prompt.items()}
        # Build the static prefix from the instruction text
        static_prefix = tokenizer(
            data_point["instruction"],
            truncation=True,
            max_length=10,
            padding="max_length",
            return_tensors="pt"
        )["input_ids"].squeeze(0)
        # Clamp indices so static_prefix never exceeds vocab_size
        static_prefix = torch.clamp(static_prefix, min=0, max=tokenizer.vocab_size - 1)
        tokenized_full_prompt["static_prefix"] = static_prefix
        # Handle dynamic sensor data (default is a tensor, not a list)
        sensor_values = torch.zeros(3, dtype=torch.float)
        if data_point.get("type") == "dynamic" and "sensor_data" in data_point:
            raw_sensor_values = data_point["sensor_data"]
            try:
                sensor_values = torch.tensor([
                    float(raw_sensor_values.get("temperature", 0.0)),
                    float(raw_sensor_values.get("humidity", 0.0)),
                    float(raw_sensor_values.get("conductivity", 0.0))
                ], dtype=torch.float)
            except Exception:
                # Fall back to zeros if the sensor payload cannot be parsed
                sensor_values = torch.zeros(3, dtype=torch.float)
        # Replace NaN/Inf readings and clamp the range to guard against outliers
        if torch.isnan(sensor_values).any() or torch.isinf(sensor_values).any():
            print("[ERROR] NaN/Inf detected in sensor_values")
            sensor_values = torch.zeros(3, dtype=torch.float)
        sensor_values = torch.clamp(sensor_values, min=-100, max=100)
        if not isinstance(sensor_values, torch.Tensor):
            sensor_values = torch.tensor(sensor_values, dtype=torch.float)
        tokenized_full_prompt["sensor_data"] = sensor_values
        # Final type checks and conversions
        for key in tokenized_full_prompt:
            if isinstance(tokenized_full_prompt[key], list):
                # Convert lists to tensors
                tokenized_full_prompt[key] = torch.tensor(tokenized_full_prompt[key])
            elif isinstance(tokenized_full_prompt[key], torch.Tensor) and tokenized_full_prompt[key].dim() > 1:
                # Squeeze extra dimensions if needed
                tokenized_full_prompt[key] = tokenized_full_prompt[key].squeeze(0)
            if key in ["input_ids", "attention_mask"] and isinstance(tokenized_full_prompt[key], list):
                tokenized_full_prompt[key] = torch.tensor(tokenized_full_prompt[key], dtype=torch.long)
        if isinstance(tokenized_full_prompt["static_prefix"], list):
            tokenized_full_prompt["static_prefix"] = torch.tensor(tokenized_full_prompt["static_prefix"],
                                                                  dtype=torch.long)
        # Ensure sensor_data is a tensor
        if not isinstance(tokenized_full_prompt["sensor_data"], torch.Tensor):
            tokenized_full_prompt["sensor_data"] = torch.tensor(tokenized_full_prompt["sensor_data"], dtype=torch.float)
        tokenized_full_prompt["labels"] = tokenized_full_prompt["input_ids"].clone()
        if not train_on_inputs:
            # Find the boundary between the user input and the assistant output
            sep = tokenizer.encode(prompter.separator)
            instruction_tokens = tokenizer.encode(data_point["instruction"])
            # Mask the user-input portion of the labels with -100
            sep_pos = tokenized_full_prompt["input_ids"].tolist().index(sep[0])
            tokenized_full_prompt["labels"][:sep_pos] = -100
        return tokenized_full_prompt
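
    # Sketch of the fields this function returns for one example (sizes assume the
    # defaults cutoff_len=512 and the 10-token static prefix; values are illustrative):
    #
    #     {
    #         "input_ids":      LongTensor[<= 512],   # tokenized full prompt
    #         "attention_mask": LongTensor[<= 512],
    #         "static_prefix":  LongTensor[10],       # clamped instruction tokens
    #         "sensor_data":    FloatTensor[3],       # temperature, humidity, conductivity
    #         "labels":         LongTensor[<= 512],   # copy of input_ids (inputs masked when train_on_inputs=False)
    #     }
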
    # Build the PrefixTuning configuration
    prefix_config = PrefixTuningConfig(
        num_virtual_tokens=num_prefix,
        task_type="CAUSAL_LM"
    )
    # Wrap the base model as a PEFT model
    peft_model = get_peft_model(model, prefix_config)
    # Build the final KoPAWithAdapter model
    final_model = KoPAWithAdapter(peft_model, num_prefix, tokenizer)
    device = next(model.parameters()).device
    print(f"[INFO] Using device: {device}")
    # Make sure final_model and all of its components live on the same device
    final_model = final_model.to(device)
    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)
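
    # Sketch of a record the JSON dataset is expected to contain, inferred from the
    # keys read in generate_and_tokenize_prompt (field values invented for illustration):
    #
    #     {
    #         "instruction": "Assess the irrigation needs of the olive grove.",
    #         "input": "Plot 12, loamy soil.",
    #         "output": "Irrigation is recommended within 48 hours.",
    #         "type": "dynamic",
    #         "sensor_data": {"temperature": 24.1, "humidity": 0.37, "conductivity": 1.8}
    #     }
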
    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only the adapter weights - the PEFT config above has to match
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            # NOTE: the weights are only loaded here; nothing below applies them to final_model.
            adapters_weights = torch.load(checkpoint_name)
        else:
            print(f"Checkpoint {checkpoint_name} not found")
    # model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        train_data = ensure_consistent_keys(train_data)
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        train_data = ensure_consistent_keys(train_data)
        val_data = None
    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True
    untie_shared_weights(final_model)
    # For KoPAWithAdapter models, we need a custom save approach
    trainer = transformers.Trainer(
        model=final_model,
        data_collator=custom_collate_fn,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=False,
            logging_steps=10,
            optim="adamw_hf",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=None,
            save_steps=10,
            output_dir=output_dir,
            save_total_limit=2,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to=None,
            run_name=None,
        ),
    )
    # final_model.config.use_cache = False
    if torch.__version__ >= "2" and sys.platform != "win32":
        # Compile the KoPAWithAdapter wrapper rather than the bare base model,
        # so the object untied and saved below still contains the adapter components.
        final_model = torch.compile(final_model)
    untie_shared_weights(final_model)
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    try:
        final_model = untie_shared_weights(final_model)
        print(f"[INFO] Saving model to {output_dir}")
        # Save the main model components
        if hasattr(final_model, "save_model"):
            final_model.save_model(output_dir)
        else:
            # Fallback if a save_model method doesn't exist
            os.makedirs(output_dir, exist_ok=True)
            # Save model configuration
            if hasattr(final_model, "config"):
                final_model.config.save_pretrained(output_dir)
            # Save model state dict (avoiding shared weights)
            model_to_save = final_model.module if hasattr(final_model, "module") else final_model
            torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
        # Save embeddings separately if they exist
        if hasattr(final_model, "embeddings"):
            torch.save(final_model.embeddings, os.path.join(output_dir, "embeddings.pth"))
        # Save PEFT model components
        if hasattr(final_model, "model") and hasattr(final_model.model, "save_pretrained"):
            peft_save_dir = os.path.join(output_dir, "peft_model")
            os.makedirs(peft_save_dir, exist_ok=True)
            final_model.model.save_pretrained(peft_save_dir)
            print(f"[INFO] PEFT model saved to {peft_save_dir}")
    except Exception as e:
        print(f"[ERROR] Error saving model: {e}")
        import traceback
        traceback.print_exc()


def inspect_model_structure(model):
    """Inspect the model structure and print information about key layers."""
    print(f"Model type: {type(model).__name__}")
    print(f"Model config: {model.config.__class__.__name__}")
    # Look for embedding layers
    embedding_layers = []
    for name, module in model.named_modules():
        if any(key in name for key in ['embed', 'wte', 'word_embeddings']):
            embedding_layers.append((name, type(module).__name__))
            if hasattr(module, 'weight'):
                print(f"Layer {name}: shape {module.weight.shape}")
    print(f"Found {len(embedding_layers)} potential embedding layers:")
    for name, type_name in embedding_layers:
        print(f"  - {name}: {type_name}")
    # Look at the attention layers
    print("\nAttention structure:")
    for name, module in model.named_modules():
        if 'attention' in name.lower():
            print(f"  - {name}: {type(module).__name__}")
if __name__ == "__main__":
fire.Fire(train)
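
# Example launch (paths are the defaults baked into train() above; adjust to your environment):
#
#     python finetune_kopa.py \
#         --base_model /root/shared-nvme/models/Qwen2.5-7B-Instruct \
#         --data_path /root/shared-nvme/dataset/olive_dataset.json \
#         --output_dir output \
#         --num_epochs 2 \
#         --num_prefix 10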