477 lines
18 KiB
Python
477 lines
18 KiB
Python
|
import os
|
|||
|
import sys
|
|||
|
from typing import List
|
|||
|
|
|||
|
import fire
|
|||
|
import torch
|
|||
|
import transformers
|
|||
|
from datasets import load_dataset
|
|||
|
|
|||
|
from kopa import KoPAWithAdapter
|
|||
|
|
|||
|
"""
|
|||
|
Unused imports:
|
|||
|
import torch.nn as nn
|
|||
|
import bitsandbytes as bnb
|
|||
|
"""
|
|||
|
|
|||
|
from peft import PrefixTuningConfig, get_peft_model
|
|||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|||
|
|
|||
|
from utils.prompter import Prompter
|
|||
|
|
|||
|
|
|||
|
def custom_collate_fn(batch, pad_token_id=0, label_pad_token_id=-100):
    """Collate a list of tokenized examples into one right-padded batch.

    Each example is a mapping with ``input_ids`` and ``attention_mask`` and,
    optionally, ``static_prefix``, ``sensor_data`` and ``labels``. Every value
    may be a plain Python list or an already-built tensor.

    Args:
        batch: Sequence of example mappings.
        pad_token_id: Value appended to ``input_ids`` so that all rows reach
            the length of the longest ``input_ids`` in the batch.
        label_pad_token_id: Value appended to ``labels``. The default of -100
            is the ignore index of the cross-entropy loss, so padded positions
            do not contribute to the training loss (the previous code padded
            labels with 0, i.e. a real token id).

    Returns:
        dict with stacked ``input_ids`` and ``attention_mask`` tensors, plus
        ``static_prefix`` / ``sensor_data`` when at least one example carries
        them, and ``labels`` when the first example carries them.
    """

    def _as_tensor(value, dtype):
        # Lists are converted; anything else is assumed to be a tensor already.
        if isinstance(value, list):
            return torch.tensor(value, dtype=dtype)
        return value

    def _right_pad(seq, length, fill):
        # Append `fill` on the right until `seq` has `length` elements.
        missing = length - seq.size(0)
        filler = torch.full((missing,), fill, dtype=seq.dtype)
        return torch.cat((seq, filler), dim=-1)

    input_ids_list = []
    attention_mask_list = []
    static_prefix_list = []
    sensor_data_list = []

    for example in batch:
        input_ids_list.append(_as_tensor(example["input_ids"], torch.long))
        attention_mask_list.append(
            _as_tensor(example["attention_mask"], torch.long)
        )
        if "static_prefix" in example:
            static_prefix_list.append(
                _as_tensor(example["static_prefix"], torch.long)
            )
        if "sensor_data" in example:
            sensor_data_list.append(
                _as_tensor(example["sensor_data"], torch.float)
            )

    # All sequence-like fields are padded to the longest input_ids row.
    max_length = max((ids.size(0) for ids in input_ids_list), default=0)

    result = {
        "input_ids": torch.stack(
            [_right_pad(ids, max_length, pad_token_id) for ids in input_ids_list]
        ),
        # Padded positions get mask 0 so the model ignores them.
        "attention_mask": torch.stack(
            [_right_pad(mask, max_length, 0) for mask in attention_mask_list]
        ),
    }

    if static_prefix_list:
        result["static_prefix"] = torch.stack(static_prefix_list)

    if sensor_data_list:
        result["sensor_data"] = torch.stack(sensor_data_list)

    if "labels" in batch[0]:
        labels_list = [_as_tensor(example["labels"], torch.long) for example in batch]
        result["labels"] = torch.stack(
            [_right_pad(lab, max_length, label_pad_token_id) for lab in labels_list]
        )

    return result
|
|||
|
|
|||
|
|
|||
|
def train(
    # model/data params
    base_model="/root/shared-nvme/models/Qwen2.5-7B-Instruct",
    data_path: str = "/root/shared-nvme/dataset/olive_dataset.json",
    output_dir: str = "output",
    # training hyperparams
    batch_size: int = 16,
    micro_batch_size: int = 16,
    num_epochs: int = 2,
    learning_rate: float = 1e-4,
    cutoff_len: int = 512,
    val_set_size: int = 0,
    num_prefix: int = 1,
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
):
    """Fine-tune a causal LM with prefix tuning wrapped in ``KoPAWithAdapter``.

    The function loads ``base_model`` in 8-bit, attaches a PEFT prefix-tuning
    adapter with ``num_prefix`` virtual tokens, wraps the result in
    ``KoPAWithAdapter``, tokenizes the dataset at ``data_path`` (adding a
    ``static_prefix`` and a 3-value ``sensor_data`` tensor to every example),
    trains with ``transformers.Trainer`` and finally saves the adapter(s) under
    ``output_dir``.

    Args:
        base_model: Path or hub id passed to ``from_pretrained``.
        data_path: A ``.json``/``.jsonl`` file, or any name accepted by
            ``datasets.load_dataset``.
        output_dir: Directory for checkpoints and the final artifacts.
        batch_size: Effective batch size; together with ``micro_batch_size``
            (and the world size under DDP) it determines the number of
            gradient-accumulation steps, which is never allowed to drop below 1.
        micro_batch_size: Per-device batch size.
        num_epochs: Number of training epochs.
        learning_rate: Optimizer learning rate.
        cutoff_len: Maximum tokenized prompt length (longer prompts are truncated).
        val_set_size: Size of the held-out split; 0 disables evaluation.
        num_prefix: Number of virtual prefix tokens.
        train_on_inputs: When False, label positions before the prompter's
            separator token are set to -100.
        add_eos_token: Only echoed in the start-up banner; not otherwise used here.
        group_by_length: Forwarded to ``TrainingArguments``.
        wandb_project, wandb_run_name, wandb_watch, wandb_log_model: Only
            echoed in the start-up banner; not otherwise used here.
        resume_from_checkpoint: Directory holding ``pytorch_model.bin`` (full
            checkpoint, handed to the Trainer) or ``adapter_model.bin``.
        prompt_template_name: Template name handed to ``Prompter``.
    """
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

    # Integer division can floor to 0 (micro_batch_size > batch_size, or after
    # dividing by the world size below); a step count of 0 is not a valid
    # setting, so clamp to at least one step.
    gradient_accumulation_steps = max(1, batch_size // micro_batch_size)

    prompter = Prompter(prompt_template_name)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # Under DDP each process owns exactly one GPU.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = max(
            1, gradient_accumulation_steps // world_size
        )

    # The Auto class picks the concrete model class from the checkpoint config.
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
        trust_remote_code=True,  # required for Qwen models per the original author
    )

    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        trust_remote_code=True,
        padding_side="left",  # left padding, as recommended for Qwen by the original author
    )
    # Reuse EOS as the padding token, and keep model/generation configs in sync.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    def ensure_consistent_keys(dataset):
        """Fill in missing ``static_prefix`` / ``sensor_data`` keys per example.

        NOTE(review): the assignments mutate the dict yielded by iteration. If
        ``dataset`` yields a fresh dict per iteration (as an Arrow-backed
        ``datasets.Dataset`` presumably does), nothing is persisted and this
        function is effectively a pass-through -- confirm before relying on it.
        """
        all_keys = set()
        for example in dataset:
            all_keys.update(example.keys())

        for example in dataset:
            for key in all_keys:
                if key not in example:
                    if key == "static_prefix":
                        example[key] = ""
                    elif key == "sensor_data":
                        example[key] = [0, 0, 0]

        return dataset

    def generate_and_tokenize_prompt(data_point):
        """Turn one raw record into tensors for the model.

        Produces ``input_ids``/``attention_mask`` for the full prompt, a
        10-token ``static_prefix`` built from the instruction, a 3-value
        ``sensor_data`` tensor and ``labels``.
        """
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )

        # Tokenize the full prompt; a single string yields tensors of shape
        # (1, seq_len), so the leading batch dimension is dropped right after.
        tokenized_full_prompt = tokenizer(
            full_prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=True,
            return_tensors='pt',
        )
        tokenized_full_prompt = {
            k: v.squeeze(0) for k, v in tokenized_full_prompt.items()
        }

        # Static prefix: the instruction, truncated/padded to exactly 10 tokens.
        static_prefix = tokenizer(
            data_point["instruction"],
            truncation=True,
            max_length=10,
            padding="max_length",
            return_tensors="pt"
        )["input_ids"].squeeze(0)

        # Keep every index inside [0, vocab_size - 1].
        # NOTE(review): added special tokens (possibly including the pad/EOS
        # token used for padding above) may have ids >= tokenizer.vocab_size
        # and would be rewritten by this clamp -- confirm this is intended.
        static_prefix = torch.clamp(static_prefix, min=0, max=tokenizer.vocab_size - 1)
        tokenized_full_prompt["static_prefix"] = static_prefix

        # Dynamic (sensor) data: always a float tensor of 3 values, defaulting
        # to zeros for non-dynamic records.
        sensor_values = torch.zeros(3, dtype=torch.float)
        if data_point["type"] == "dynamic" and "sensor_data" in data_point:
            raw_sensor_values = data_point["sensor_data"]
            try:
                sensor_values = torch.tensor([
                    float(raw_sensor_values.get("temperature", 0.0)),
                    float(raw_sensor_values.get("humidity", 0.0)),
                    float(raw_sensor_values.get("conductivity", 0.0))
                ], dtype=torch.float)
            except (AttributeError, TypeError, ValueError):
                # Malformed sensor payload (not a mapping, or non-numeric
                # values): deliberately fall back to zeros without logging,
                # as the original did.
                sensor_values = torch.zeros(3, dtype=torch.float)

        # clamp() would let NaN through, so non-finite readings are reset
        # before clamping. (The original ran this check twice in a row.)
        if torch.isnan(sensor_values).any() or torch.isinf(sensor_values).any():
            print("[ERROR] NaN/Inf detected in sensor_values")
            sensor_values = torch.zeros(3, dtype=torch.float)

        # Bound the values to guard against outliers.
        sensor_values = torch.clamp(sensor_values, min=-100, max=100)
        tokenized_full_prompt["sensor_data"] = sensor_values

        # Every value is a 1-D tensor at this point, so the list->tensor and
        # squeeze passes of the original were unreachable and have been dropped.
        tokenized_full_prompt["labels"] = tokenized_full_prompt["input_ids"].clone()

        # When the loss should not cover the input part, set those labels to -100.
        if not train_on_inputs:
            # Locate the boundary between user input and assistant output via
            # the first token of the prompter's separator.
            sep = tokenizer.encode(prompter.separator)
            # NOTE(review): .index raises ValueError when that token is absent
            # (e.g. truncated away) -- behavior kept from the original.
            sep_pos = tokenized_full_prompt["input_ids"].tolist().index(sep[0])
            tokenized_full_prompt["labels"][:sep_pos] = -100

        return tokenized_full_prompt

    # Prefix-tuning configuration and the resulting PEFT model.
    prefix_config = PrefixTuningConfig(
        num_virtual_tokens=num_prefix,
        task_type="CAUSAL_LM"
    )
    peft_model = get_peft_model(model, prefix_config)

    # Final model: KoPAWithAdapter around the PEFT model.
    final_model = KoPAWithAdapter(peft_model, num_prefix, tokenizer)
    device = next(model.parameters()).device
    print(f"[INFO] 使用设备: {device}")

    # Make sure the wrapper and its components live on the same device as the base model.
    final_model = final_model.to(device)

    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # adapter-only checkpoint
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            # NOTE(review): the weights are loaded but never applied to any
            # model in this function, so an adapter-only checkpoint is not
            # actually restored. Applying them needs knowledge of the
            # checkpoint layout that is not visible here -- confirm and wire up.
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints.
            adapters_weights = torch.load(checkpoint_name)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        train_data = ensure_consistent_keys(train_data)
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        train_data = ensure_consistent_keys(train_data)
        val_data = None

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=final_model,
        data_collator=custom_collate_fn,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            optim="adamw_hf",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=None,
            save_steps=5000,
            output_dir=output_dir,
            save_total_limit=2,
            load_best_model_at_end=val_set_size > 0,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to=None,
            run_name=None,
        ),
    )

    # The original rebound `final_model = torch.compile(model)` at this point.
    # The Trainer already held the uncompiled wrapper, so training was not
    # affected, but every save below then operated on the compiled *base*
    # model instead of the trained KoPA wrapper. The rebinding is removed so
    # `final_model` keeps referring to the object that was actually trained.

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # KoPAWithAdapter's API is not visible from this file: use its own
    # save_pretrained when it has one, otherwise save the PEFT adapter directly.
    if hasattr(final_model, "save_pretrained"):
        final_model.save_pretrained(output_dir)
    else:
        peft_model.save_pretrained(output_dir)

    # Only save the embeddings when the wrapper actually exposes them.
    if hasattr(final_model, "embeddings"):
        torch.save(final_model.embeddings, os.path.join(output_dir, "embeddings.pth"))
    else:
        print("[WARNING] final_model没有embeddings属性,跳过保存。")

    # Best-effort extra copy of the wrapped PEFT model; failures are reported
    # but do not abort the run.
    try:
        final_model.model.save_pretrained(os.path.join(output_dir, "peft_model"))
        print(f"[INFO] PEFT模型保存到 {os.path.join(output_dir, 'peft_model')}")
    except Exception as e:
        print(f"[WARNING] 保存PEFT模型时出错: {e}")
|
|||
|
|
|||
|
def inspect_model_structure(model):
    """Print a summary of the model: its type, embedding-like layers and attention modules.

    A submodule counts as embedding-like when its qualified name contains
    'embed', 'wte' or 'word_embeddings'; it counts as attention when its
    lower-cased name contains 'attention'. Nothing is returned.
    """
    print(f"Model type: {type(model).__name__}")
    print(f"Model config: {model.config.__class__.__name__}")

    # Embedding layers: report weight shapes while scanning, then list them.
    embedding_markers = ('embed', 'wte', 'word_embeddings')
    matches = []
    for module_name, submodule in model.named_modules():
        if not any(marker in module_name for marker in embedding_markers):
            continue
        matches.append((module_name, type(submodule).__name__))
        if hasattr(submodule, 'weight'):
            print(f"Layer {module_name}: shape {submodule.weight.shape}")

    print(f"Found {len(matches)} potential embedding layers:")
    for module_name, class_name in matches:
        print(f" - {module_name}: {class_name}")

    # Attention layers.
    print("\nAttention structure:")
    for module_name, submodule in model.named_modules():
        if 'attention' not in module_name.lower():
            continue
        print(f" - {module_name}: {type(submodule).__name__}")
|
|||
|
|
|||
|
|
|||
|
# Script entry point: hand `train` to python-fire so it is invoked from the
# command line.
if __name__ == "__main__":
    fire.Fire(train)
|