# kopa/.ipynb_checkpoints/finetune_kopa-checkpoint.py
import os
import sys
from typing import List
import fire
import torch
import transformers
from datasets import load_dataset
from kopa import KoPAWithAdapter
"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""
from peft import PrefixTuningConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.prompter import Prompter
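
# Collate function for the Trainer: pads `input_ids`, `attention_mask` and
# `labels` to the longest sequence in the batch and stacks the optional
# `static_prefix` / `sensor_data` fields into batch tensors.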
def custom_collate_fn(batch):
    input_ids_list = []
    attention_mask_list = []
    static_prefix_list = []
    sensor_data_list = []
    for b in batch:
        # Make sure every field is a tensor
        if isinstance(b["input_ids"], list):
            input_ids = torch.tensor(b["input_ids"], dtype=torch.long)
        else:
            input_ids = b["input_ids"]
        input_ids_list.append(input_ids)
        if isinstance(b["attention_mask"], list):
            attention_mask = torch.tensor(b["attention_mask"], dtype=torch.long)
        else:
            attention_mask = b["attention_mask"]
        attention_mask_list.append(attention_mask)
        if "static_prefix" in b:
            if isinstance(b["static_prefix"], list):
                static_prefix = torch.tensor(b["static_prefix"], dtype=torch.long)
            else:
                static_prefix = b["static_prefix"]
            static_prefix_list.append(static_prefix)
        if "sensor_data" in b:
            if isinstance(b["sensor_data"], list):
                sensor_data = torch.tensor(b["sensor_data"], dtype=torch.float)
            else:
                sensor_data = b["sensor_data"]
            sensor_data_list.append(sensor_data)
    # Pad input_ids and attention_mask to the longest sequence in the batch
    max_length = 0
    for one_inputs in input_ids_list:
        max_length = max(max_length, one_inputs.size(0))
    input_ids_list_ = []
    for one_inputs in input_ids_list:
        pad = torch.full((max_length - one_inputs.size(0),), 0, dtype=torch.long)
        input_ids_list_.append(torch.cat((one_inputs, pad), dim=-1))
    attention_mask_list_ = []
    for mask in attention_mask_list:
        pad = torch.full((max_length - mask.size(0),), 0, dtype=torch.long)
        attention_mask_list_.append(torch.cat((mask, pad), dim=-1))
    # Stack everything into batch tensors
    result = {
        "input_ids": torch.stack(input_ids_list_),
        "attention_mask": torch.stack(attention_mask_list_),
    }
    if static_prefix_list:
        result["static_prefix"] = torch.stack(static_prefix_list)
    if sensor_data_list:
        result["sensor_data"] = torch.stack(sensor_data_list)
    if "labels" in batch[0]:
        labels_list = []
        for b in batch:
            if isinstance(b["labels"], list):
                labels = torch.tensor(b["labels"], dtype=torch.long)
            else:
                labels = b["labels"]
            labels_list.append(labels)
        labels_list_ = []
        for label in labels_list:
            # Pad labels with -100 so padded positions are ignored by the loss
            pad = torch.full((max_length - label.size(0),), -100, dtype=torch.long)
            labels_list_.append(torch.cat((label, pad), dim=-1))
        result["labels"] = torch.stack(labels_list_)
    return result
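
# Fine-tuning entry point: loads the base model in 8-bit, attaches a PEFT
# prefix-tuning adapter, wraps it in KoPAWithAdapter, tokenizes the instruction
# dataset and trains with transformers.Trainer using custom_collate_fn.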
def train(
        # model/data params
        base_model: str = "/root/shared-nvme/models/Qwen2.5-7B-Instruct",
        data_path: str = "/root/shared-nvme/dataset/olive_dataset.json",
        output_dir: str = "output",
        # training hyperparams
        batch_size: int = 16,
        micro_batch_size: int = 16,
        num_epochs: int = 2,
        learning_rate: float = 1e-4,
        cutoff_len: int = 512,
        val_set_size: int = 0,
        num_prefix: int = 1,
        # llm hyperparams
        train_on_inputs: bool = True,  # if False, masks out inputs in loss
        add_eos_token: bool = False,
        group_by_length: bool = False,  # faster, but produces an odd training loss curve
        # wandb params
        wandb_project: str = "",
        wandb_run_name: str = "",
        wandb_watch: str = "",  # options: false | gradients | all
        wandb_log_model: str = "",  # options: false | true
        resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
        prompt_template_name: str = "alpaca",  # the prompt template to use, defaults to alpaca
):
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    gradient_accumulation_steps = batch_size // micro_batch_size
    prompter = Prompter(prompt_template_name)
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        # the Auto* classes resolve the correct model type automatically
        torch_dtype=torch.float16,
        device_map=device_map,
        trust_remote_code=True,  # required for Qwen models
    )
    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        trust_remote_code=True,  # required for Qwen tokenizers
        padding_side="left",  # left padding is also recommended for Qwen
    )
    tokenizer.pad_token = tokenizer.eos_token
    # tokenizer.pad_token_id = (
    #     0  # unk. we want this to be different from the eos token
    # )
    # tokenizer.padding_side = "left"  # Allow batched inference
    # model.gradient_checkpointing_enable()
    model.config.pad_token_id = model.config.eos_token_id
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
    def ensure_consistent_keys(dataset):
        # Make every example expose the same keys so features line up after .map()
        all_keys = set()
        for example in dataset:
            all_keys.update(example.keys())
        for example in dataset:
            for key in all_keys:
                if key not in example:
                    if key == "static_prefix":
                        example[key] = ""
                    elif key == "sensor_data":
                        example[key] = [0, 0, 0]
        return dataset
    # def tokenize(prompt, add_eos_token=True):
    #     # there's probably a way to do this with the tokenizer settings
    #     # but again, gotta move fast
    #     result = tokenizer(
    #         prompt,
    #         truncation=True,
    #         max_length=cutoff_len,
    #         padding=False,
    #         return_tensors=None,
    #     )
    #     if (
    #         result["input_ids"][-1] != tokenizer.eos_token_id
    #         and len(result["input_ids"]) < cutoff_len
    #         and add_eos_token
    #     ):
    #         result["input_ids"].append(tokenizer.eos_token_id)
    #         result["attention_mask"].append(1)
    #
    #     result["labels"] = result["input_ids"].copy()
    #
    #     return result
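    # Per-example preprocessing used with datasets.map() below: builds the
    # Alpaca-style prompt, tokenizes it, derives a fixed-length `static_prefix`
    # from the instruction and a 3-dim `sensor_data` tensor (temperature,
    # humidity, conductivity), and copies `input_ids` into `labels`.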
    def generate_and_tokenize_prompt(data_point):
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
        # Tokenize the full prompt
        tokenized_full_prompt = tokenizer(
            full_prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=True,
            return_tensors="pt",
        )
        tokenized_full_prompt = {k: v.squeeze(0) for k, v in tokenized_full_prompt.items()}
        # Static prefix: the instruction tokenized to a fixed length of 10
        static_prefix = tokenizer(
            data_point["instruction"],
            truncation=True,
            max_length=10,
            padding="max_length",
            return_tensors="pt",
        )["input_ids"].squeeze(0)
        # Clamp indices so static_prefix never exceeds the vocabulary size
        static_prefix = torch.clamp(static_prefix, min=0, max=tokenizer.vocab_size - 1)
        tokenized_full_prompt["static_prefix"] = static_prefix
        # Dynamic sensor data: default to a zero tensor rather than a list
        sensor_values = torch.zeros(3, dtype=torch.float)
        if data_point["type"] == "dynamic" and "sensor_data" in data_point:
            raw_sensor_values = data_point["sensor_data"]
            try:
                sensor_values = torch.tensor([
                    float(raw_sensor_values.get("temperature", 0.0)),
                    float(raw_sensor_values.get("humidity", 0.0)),
                    float(raw_sensor_values.get("conductivity", 0.0)),
                ], dtype=torch.float)
            except Exception as e:
                print(f"[ERROR] failed to parse sensor_data: {raw_sensor_values}, {e}")
                sensor_values = torch.zeros(3, dtype=torch.float)
        # Guard against NaN/Inf and clip extreme values
        if torch.isnan(sensor_values).any() or torch.isinf(sensor_values).any():
            print("[ERROR] NaN/Inf detected in sensor_values")
            sensor_values = torch.zeros(3, dtype=torch.float)
        sensor_values = torch.clamp(sensor_values, min=-100, max=100)
        if not isinstance(sensor_values, torch.Tensor):
            sensor_values = torch.tensor(sensor_values, dtype=torch.float)
        tokenized_full_prompt["sensor_data"] = sensor_values  # always a tensor
        # Final type checks and conversions
        for key in tokenized_full_prompt:
            if isinstance(tokenized_full_prompt[key], list):
                # Convert lists to tensors
                tokenized_full_prompt[key] = torch.tensor(tokenized_full_prompt[key])
            elif isinstance(tokenized_full_prompt[key], torch.Tensor) and tokenized_full_prompt[key].dim() > 1:
                # Squeeze extra dimensions if needed
                tokenized_full_prompt[key] = tokenized_full_prompt[key].squeeze(0)
            if key in ["input_ids", "attention_mask"] and isinstance(tokenized_full_prompt[key], list):
                tokenized_full_prompt[key] = torch.tensor(tokenized_full_prompt[key], dtype=torch.long)
        if isinstance(tokenized_full_prompt["static_prefix"], list):
            tokenized_full_prompt["static_prefix"] = torch.tensor(
                tokenized_full_prompt["static_prefix"], dtype=torch.long
            )
        # Make sure sensor_data is a tensor
        if not isinstance(tokenized_full_prompt["sensor_data"], torch.Tensor):
            tokenized_full_prompt["sensor_data"] = torch.tensor(
                tokenized_full_prompt["sensor_data"], dtype=torch.float
            )
        tokenized_full_prompt["labels"] = tokenized_full_prompt["input_ids"].clone()
        # If we don't want to compute loss on the input part, set its labels to -100
        if not train_on_inputs:
            # Find the boundary between the user input and the assistant output
            sep = tokenizer.encode(prompter.separator)
            instruction_tokens = tokenizer.encode(data_point["instruction"])
            # Mask out the user-input part of the labels
            sep_pos = tokenized_full_prompt["input_ids"].tolist().index(sep[0])
            tokenized_full_prompt["labels"][:sep_pos] = -100
        return tokenized_full_prompt
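    # Note: the extra `static_prefix` / `sensor_data` batch fields produced by
    # custom_collate_fn are presumably consumed inside KoPAWithAdapter (see kopa.py).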
    # Create the PrefixTuning configuration
    prefix_config = PrefixTuningConfig(
        num_virtual_tokens=num_prefix,
        task_type="CAUSAL_LM",
    )
    # Create the PEFT model
    peft_model = get_peft_model(model, prefix_config)
    # Create the final KoPAWithAdapter model
    final_model = KoPAWithAdapter(peft_model, num_prefix, tokenizer)
    device = next(model.parameters()).device
    print(f"[INFO] using device: {device}")
    # Keep final_model and all of its components on the same device
    final_model = final_model.to(device)
    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)
    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only the adapter weights - the PEFT config above has to match
            resume_from_checkpoint = (
                False  # so the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved,
        # but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
        else:
            print(f"Checkpoint {checkpoint_name} not found")
    # model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        train_data = ensure_consistent_keys(train_data)
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        train_data = ensure_consistent_keys(train_data)
        val_data = None
    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True
    trainer = transformers.Trainer(
        model=final_model,
        data_collator=custom_collate_fn,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            optim="adamw_hf",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=None,
            save_steps=5000,
            output_dir=output_dir,
            save_total_limit=2,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to=None,
            run_name=None,
        ),
    )
    # final_model.config.use_cache = False
    if torch.__version__ >= "2" and sys.platform != "win32":
        # Compile the wrapped model on PyTorch 2+ for faster training
        final_model = torch.compile(final_model)
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    final_model.save_pretrained(output_dir)
    # Only save the embeddings if the wrapper actually exposes them
    if hasattr(final_model, "embeddings"):
        torch.save(final_model.embeddings, os.path.join(output_dir, "embeddings.pth"))
    else:
        print("[WARNING] final_model has no `embeddings` attribute, skipping save.")
    try:
        final_model.model.save_pretrained(os.path.join(output_dir, "peft_model"))
        print(f"[INFO] PEFT model saved to {os.path.join(output_dir, 'peft_model')}")
    except Exception as e:
        print(f"[WARNING] error while saving the PEFT model: {e}")

def inspect_model_structure(model):
    """Inspect the model structure and print information about key layers."""
    print(f"Model type: {type(model).__name__}")
    print(f"Model config: {model.config.__class__.__name__}")
    # Look for embedding layers
    embedding_layers = []
    for name, module in model.named_modules():
        if any(key in name for key in ['embed', 'wte', 'word_embeddings']):
            embedding_layers.append((name, type(module).__name__))
            if hasattr(module, 'weight'):
                print(f"Layer {name}: shape {module.weight.shape}")
    print(f"Found {len(embedding_layers)} potential embedding layers:")
    for name, type_name in embedding_layers:
        print(f"  - {name}: {type_name}")
    # Look at the attention layers
    print("\nAttention structure:")
    for name, module in model.named_modules():
        if 'attention' in name.lower():
            print(f"  - {name}: {type(module).__name__}")


if __name__ == "__main__":
    fire.Fire(train)
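
# Illustrative launch command (a minimal sketch; the script name and paths are
# placeholders, and every flag simply mirrors a parameter of train(), which
# fire.Fire exposes on the command line):
#   python finetune_kopa.py \
#       --base_model /path/to/Qwen2.5-7B-Instruct \
#       --data_path /path/to/olive_dataset.json \
#       --output_dir output \
#       --num_epochs 2 \
#       --num_prefix 1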