Qwen adaptation

This commit is contained in:
黄子寒 2025-03-17 20:17:41 +08:00
parent 1b3dd9475c
commit 4f0be21236
18 changed files with 1210 additions and 44 deletions

View File

@ -0,0 +1,273 @@
import os
import sys
from typing import List
import time
import fire
import torch
import transformers
from datasets import load_dataset
"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer
from utils.prompter import Prompter
def train(
# model/data params
base_model: str = "", # the only required argument
data_path: str = "YOUR LLM PATH",
output_dir: str = "./lora-alpaca",
# training hyperparams
batch_size: int = 16,
micro_batch_size: int = 16,
num_epochs: int = 2,
learning_rate: float = 3e-4,
cutoff_len: int = 512,
val_set_size: int = 0,
# lora hyperparams
lora_r: int = 16,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = [
"q_proj",
"v_proj",
],
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = False,
group_by_length: bool = False, # faster, but produces an odd training loss curve
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
):
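"""
Fine-tune a LLaMA-family base model with LoRA on Alpaca-style instruction data.
All defaults can be overridden from the command line via fire (see
fire.Fire(train) at the bottom of this file).
"""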
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
f"Training Alpaca-LoRA model with params:\n"
f"base_model: {base_model}\n"
f"data_path: {data_path}\n"
f"output_dir: {output_dir}\n"
f"batch_size: {batch_size}\n"
f"micro_batch_size: {micro_batch_size}\n"
f"num_epochs: {num_epochs}\n"
f"learning_rate: {learning_rate}\n"
f"cutoff_len: {cutoff_len}\n"
f"val_set_size: {val_set_size}\n"
f"lora_r: {lora_r}\n"
f"lora_alpha: {lora_alpha}\n"
f"lora_dropout: {lora_dropout}\n"
f"lora_target_modules: {lora_target_modules}\n"
f"train_on_inputs: {train_on_inputs}\n"
f"add_eos_token: {add_eos_token}\n"
f"group_by_length: {group_by_length}\n"
f"wandb_project: {wandb_project}\n"
f"wandb_run_name: {wandb_run_name}\n"
f"wandb_watch: {wandb_watch}\n"
f"wandb_log_model: {wandb_log_model}\n"
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
)
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
gradient_accumulation_steps = batch_size // micro_batch_size
prompter = Prompter(prompt_template_name)
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
gradient_accumulation_steps = gradient_accumulation_steps // world_size
model = LlamaForCausalLM.from_pretrained(
base_model,
# load_in_8bit=True,
torch_dtype=torch.float16,
device_map=device_map,
)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference
def tokenize(prompt, add_eos_token=True):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
def generate_and_tokenize_prompt(data_point):
full_prompt = prompter.generate_prompt(
data_point["instruction"],
data_point["input"],
data_point["output"],
)
tokenized_full_prompt = tokenize(full_prompt)
if not train_on_inputs:
user_prompt = prompter.generate_prompt(
data_point["instruction"], data_point["input"]
)
tokenized_user_prompt = tokenize(
user_prompt, add_eos_token=add_eos_token
)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
if add_eos_token:
user_prompt_len -= 1
tokenized_full_prompt["labels"] = [
-100
] * user_prompt_len + tokenized_full_prompt["labels"][
user_prompt_len:
] # could be sped up, probably
return tokenized_full_prompt
model = prepare_model_for_int8_training(model)
config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
data = load_dataset("json", data_files=data_path)
else:
data = load_dataset(data_path)
if resume_from_checkpoint:
# Check the available weights and load them
checkpoint_name = os.path.join(
resume_from_checkpoint, "pytorch_model.bin"
) # Full checkpoint
if not os.path.exists(checkpoint_name):
checkpoint_name = os.path.join(
resume_from_checkpoint, "adapter_model.bin"
) # only LoRA model - LoRA config above has to fit
resume_from_checkpoint = (
False # So the trainer won't try loading its state
)
# The two files above have a different name depending on how they were saved, but are actually the same.
if os.path.exists(checkpoint_name):
print(f"Restarting from {checkpoint_name}")
adapters_weights = torch.load(checkpoint_name)
set_peft_model_state_dict(model, adapters_weights)
else:
print(f"Checkpoint {checkpoint_name} not found")
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
if val_set_size > 0:
train_val = data["train"].train_test_split(
test_size=val_set_size, shuffle=True, seed=42
)
train_data = (
train_val["train"].shuffle().map(generate_and_tokenize_prompt)
)
val_data = (
train_val["test"].shuffle().map(generate_and_tokenize_prompt)
)
else:
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = None
if not ddp and torch.cuda.device_count() > 1:
# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
model.is_parallelizable = True
model.model_parallel = True
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=100,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
fp16=True,
logging_steps=10,
optim="adamw_torch",
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=None,
save_steps=8000,
output_dir=output_dir,
save_total_limit=2,
load_best_model_at_end=True if val_set_size > 0 else False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
report_to=None,
run_name=None,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(
self, old_state_dict()
)
).__get__(model, type(model))
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
model.save_pretrained(output_dir)
print(
"\n If there's a warning about missing keys above, please disregard :)"
)
if __name__ == "__main__":
fire.Fire(train)

View File

@ -0,0 +1,476 @@
import os
import sys
from typing import List
import fire
import torch
import transformers
from datasets import load_dataset
from kopa import KoPAWithAdapter
"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""
from peft import PrefixTuningConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.prompter import Prompter
def custom_collate_fn(batch):
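"""
Collate a list of feature dicts into batched tensors: convert lists to tensors,
right-pad input_ids / attention_mask / labels to the longest sequence in the
batch, and stack the optional static_prefix and sensor_data fields.
"""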
input_ids_list = []
attention_mask_list = []
static_prefix_list = []
sensor_data_list = []
# qwen_dict= {'llama_eos_tid':, 'qwen_eos_tid':}
for b in batch:
# Make sure inputs are tensors
if isinstance(b["input_ids"], list):
input_ids = torch.tensor(b["input_ids"], dtype=torch.long)
else:
input_ids = b["input_ids"]
input_ids_list.append(input_ids)
if isinstance(b["attention_mask"], list):
attention_mask = torch.tensor(b["attention_mask"], dtype=torch.long)
else:
attention_mask = b["attention_mask"]
attention_mask_list.append(attention_mask)
if "static_prefix" in b:
if isinstance(b["static_prefix"], list):
static_prefix = torch.tensor(b["static_prefix"], dtype=torch.long)
else:
static_prefix = b["static_prefix"]
static_prefix_list.append(static_prefix)
if "sensor_data" in b:
if isinstance(b["sensor_data"], list):
sensor_data = torch.tensor(b["sensor_data"], dtype=torch.float)
else:
sensor_data = b["sensor_data"]
sensor_data_list.append(sensor_data)
max_length=0
for one_inputs in input_ids_list:
max_length = one_inputs.size(0) if max_length < one_inputs.size(0) else max_length
input_ids_list_=list()
for one_inputs in input_ids_list:
input_ids_list_.append(torch.cat((one_inputs, torch.full((max_length-one_inputs.size(0),), 0, dtype=torch.int)), dim=-1))
attention_mask_list_=list()
for mask in attention_mask_list:
attention_mask_list_.append(torch.cat((mask, torch.full((max_length-mask.size(0),), 0, dtype=torch.int)), dim=-1))
# print("=====",input_ids_list)
# exit(0)
# Stack into batch tensors
result = {
"input_ids": torch.stack(input_ids_list_),
"attention_mask": torch.stack(attention_mask_list_),
}
if static_prefix_list:
result["static_prefix"] = torch.stack(static_prefix_list)
if sensor_data_list:
result["sensor_data"] = torch.stack(sensor_data_list)
if "labels" in batch[0]:
labels_list = []
for b in batch:
if isinstance(b["labels"], list):
labels = torch.tensor(b["labels"], dtype=torch.long)
else:
labels = b["labels"]
labels_list.append(labels)
labels_list_ = list()
for label in labels_list:
# pad labels with -100 so the padded positions are ignored by the loss
labels_list_.append(torch.cat((label, torch.full((max_length-label.size(0),), -100, dtype=torch.long)), dim=-1))
result["labels"] = torch.stack(labels_list_)
return result
def train(
# model/data params
base_model="/root/shared-nvme/models/Qwen2.5-7B-Instruct",
data_path: str = "/root/shared-nvme/dataset/olive_dataset.json",
output_dir: str = "output",
# training hyperparams
batch_size: int = 16,
micro_batch_size: int = 16,
num_epochs: int = 2,
learning_rate: float = 1e-4,
cutoff_len: int = 512,
val_set_size: int = 0,
num_prefix: int = 1,
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = False,
group_by_length: bool = False, # faster, but produces an odd training loss curve
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
):
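"""
Fine-tune a Qwen2.5-style base model wrapped in KoPAWithAdapter: the base model
is loaded in 8-bit, wrapped with PEFT prefix tuning, and extended with a static
prefix embedding plus a sensor-data MLP adapter. Defaults can be overridden from
the command line via fire.
"""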
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
f"Training Alpaca model with params:\n"
f"base_model: {base_model}\n"
f"data_path: {data_path}\n"
f"output_dir: {output_dir}\n"
f"batch_size: {batch_size}\n"
f"micro_batch_size: {micro_batch_size}\n"
f"num_epochs: {num_epochs}\n"
f"learning_rate: {learning_rate}\n"
f"cutoff_len: {cutoff_len}\n"
f"val_set_size: {val_set_size}\n"
f"train_on_inputs: {train_on_inputs}\n"
f"add_eos_token: {add_eos_token}\n"
f"group_by_length: {group_by_length}\n"
f"wandb_project: {wandb_project}\n"
f"wandb_run_name: {wandb_run_name}\n"
f"wandb_watch: {wandb_watch}\n"
f"wandb_log_model: {wandb_log_model}\n"
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
)
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
gradient_accumulation_steps = batch_size // micro_batch_size
prompter = Prompter(prompt_template_name)
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
gradient_accumulation_steps = gradient_accumulation_steps // world_size
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=True,
# the Auto class picks the correct model type automatically
torch_dtype=torch.float16,
device_map=device_map,
trust_remote_code=True,  # required for Qwen models
)
tokenizer = AutoTokenizer.from_pretrained(
base_model,
trust_remote_code=True,  # required for Qwen tokenizers as well
padding_side="left",  # left padding is also recommended for Qwen
)
tokenizer.pad_token = tokenizer.eos_token
# tokenizer.pad_token_id = (
# 0 # unk. we want this to be different from the eos token
# )
# tokenizer.padding_side = "left" # Allow batched inference
# model.gradient_checkpointing_enable()
# tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
model.generation_config.pad_token_id = model.generation_config.eos_token_id
def ensure_consistent_keys(dataset):
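"""Make sure every example exposes the same keys, filling in neutral defaults
for missing static_prefix ("") and sensor_data ([0, 0, 0]) fields."""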
all_keys = set()
for example in dataset:
all_keys.update(example.keys())
for example in dataset:
for key in all_keys:
if key not in example:
if key == "static_prefix":
example[key] = ""
elif key == "sensor_data":
example[key] = [0, 0, 0]
return dataset
# def tokenize(prompt, add_eos_token=True):
# # there's probably a way to do this with the tokenizer settings
# # but again, gotta move fast
# result = tokenizer(
# prompt,
# truncation=True,
# max_length=cutoff_len,
# padding=False,
# return_tensors=None,
# )
# if (
# result["input_ids"][-1] != tokenizer.eos_token_id
# and len(result["input_ids"]) < cutoff_len
# and add_eos_token
# ):
# result["input_ids"].append(tokenizer.eos_token_id)
# result["attention_mask"].append(1)
#
# result["labels"] = result["input_ids"].copy()
#
# return result
def generate_and_tokenize_prompt(data_point):
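"""Build the full Alpaca-style prompt for one example, tokenize it, and attach
static_prefix token ids (from the instruction), a 3-dim sensor_data tensor, and
labels cloned from input_ids (optionally masking the input part)."""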
full_prompt = prompter.generate_prompt(
data_point["instruction"],
data_point["input"],
data_point["output"],
)
# Tokenize the full prompt text
tokenized_full_prompt = tokenizer(
full_prompt,
truncation=True,
max_length=cutoff_len,
padding=True,
return_tensors='pt',
)
# for k,v in tokenized_full_prompt.items(): print("======k,v",k,v,type(k),type(v))
# exit(0)
tokenized_full_prompt = {k: v.squeeze(0) for k, v in tokenized_full_prompt.items()}
# Build the static prefix from the instruction
static_prefix = tokenizer(
data_point["instruction"],
truncation=True,
max_length=10,
padding="max_length",
return_tensors="pt"
)["input_ids"].squeeze(0)
# Clamp indices so static_prefix never exceeds vocab_size
static_prefix = torch.clamp(static_prefix, min=0, max=tokenizer.vocab_size - 1)
tokenized_full_prompt["static_prefix"] = static_prefix
# print(f"[DEBUG] static_prefix (after clamp): {static_prefix}")
# print(f"[DEBUG] tokenizer vocab_size: {tokenizer.vocab_size}")
# Dynamic (sensor) data
sensor_values = torch.zeros(3, dtype=torch.float)  # default is a Tensor, not a list
if data_point["type"] == "dynamic" and "sensor_data" in data_point:
raw_sensor_values = data_point["sensor_data"]
try:
sensor_values = torch.tensor([
float(raw_sensor_values.get("temperature", 0.0)),
float(raw_sensor_values.get("humidity", 0.0)),
float(raw_sensor_values.get("conductivity", 0.0))
], dtype=torch.float)
except Exception as e:
# fall back to zeros if sensor_data cannot be parsed
# print(f"[ERROR] failed to parse sensor_data: {raw_sensor_values}, {e}")
sensor_values = torch.zeros(3, dtype=torch.float)
# Make sure sensor_values contains no NaN/Inf
if torch.isnan(sensor_values).any() or torch.isinf(sensor_values).any():
print(f"[ERROR] NaN/Inf detected in sensor_values")
sensor_values = torch.zeros(3, dtype=torch.float)
# Clamp the range to guard against outliers
sensor_values = torch.clamp(sensor_values, min=-100, max=100)
# print(f"[DEBUG] sensor_values (AFTER FIX): {sensor_values}")
if not isinstance(sensor_values, torch.Tensor):
sensor_values = torch.tensor(sensor_values, dtype=torch.float)
tokenized_full_prompt["sensor_data"] = sensor_values # **确保始终是 Tensor**
# 最后增加类型检查和转换
for key in tokenized_full_prompt:
if isinstance(tokenized_full_prompt[key], list):
# Convert lists to tensors
tokenized_full_prompt[key] = torch.tensor(tokenized_full_prompt[key])
elif isinstance(tokenized_full_prompt[key], torch.Tensor) and tokenized_full_prompt[key].dim() > 1:
# Squeeze extra dimensions if needed
tokenized_full_prompt[key] = tokenized_full_prompt[key].squeeze(0)
if key in ["input_ids", "attention_mask"] and isinstance(tokenized_full_prompt[key], list):
tokenized_full_prompt[key] = torch.tensor(tokenized_full_prompt[key], dtype=torch.long)
if isinstance(tokenized_full_prompt["static_prefix"], list):
tokenized_full_prompt["static_prefix"] = torch.tensor(tokenized_full_prompt["static_prefix"],
dtype=torch.long)
# Make sure sensor_data is a tensor
if not isinstance(tokenized_full_prompt["sensor_data"], torch.Tensor):
tokenized_full_prompt["sensor_data"] = torch.tensor(tokenized_full_prompt["sensor_data"], dtype=torch.float)
tokenized_full_prompt["labels"] = tokenized_full_prompt["input_ids"].clone()
# If we do not want to compute loss on the input part, set its labels to -100
if not train_on_inputs:
# Find the boundary between the user input and the assistant output
sep = tokenizer.encode(prompter.separator)
instruction_tokens = tokenizer.encode(data_point["instruction"])
# Set labels for the user-input part to -100
sep_pos = tokenized_full_prompt["input_ids"].tolist().index(sep[0])
tokenized_full_prompt["labels"][:sep_pos] = -100
return tokenized_full_prompt
# Create the PrefixTuning config
prefix_config = PrefixTuningConfig(
num_virtual_tokens=num_prefix,
task_type="CAUSAL_LM"
)
# Create the PEFT model
peft_model = get_peft_model(model, prefix_config)
# Build the final KoPAWithAdapter model
final_model = KoPAWithAdapter(peft_model, num_prefix, tokenizer)
device = next(model.parameters()).device
print(f"[INFO] Using device: {device}")
# Make sure final_model and its components are on the same device
final_model = final_model.to(device)
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
data = load_dataset("json", data_files=data_path)
else:
data = load_dataset(data_path)
if resume_from_checkpoint:
# Check the available weights and load them
checkpoint_name = os.path.join(
resume_from_checkpoint, "pytorch_model.bin"
) # Full checkpoint
if not os.path.exists(checkpoint_name):
checkpoint_name = os.path.join(
resume_from_checkpoint, "adapter_model.bin"
) # only LoRA model - LoRA config above has to fit
resume_from_checkpoint = (
False # So the trainer won't try loading its state
)
# The two files above have a different name depending on how they were saved, but are actually the same.
if os.path.exists(checkpoint_name):
print(f"Restarting from {checkpoint_name}")
adapters_weights = torch.load(checkpoint_name)
else:
print(f"Checkpoint {checkpoint_name} not found")
# model.print_trainable_parameters() # Be more transparent about the % of trainable params.
if val_set_size > 0:
train_val = data["train"].train_test_split(
test_size=val_set_size, shuffle=True, seed=42
)
train_data = (
train_val["train"].shuffle().map(generate_and_tokenize_prompt)
)
train_data = ensure_consistent_keys(train_data)
val_data = (
train_val["test"].shuffle().map(generate_and_tokenize_prompt)
)
else:
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
train_data = ensure_consistent_keys(train_data)
val_data = None
if not ddp and torch.cuda.device_count() > 1:
# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
model.is_parallelizable = True
model.model_parallel = True
trainer = transformers.Trainer(
model=final_model,
data_collator=custom_collate_fn,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=100,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
fp16=True,
logging_steps=10,
optim="adamw_hf",
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=None,
save_steps=5000,
output_dir=output_dir,
save_total_limit=2,
load_best_model_at_end=True if val_set_size > 0 else False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
report_to=None,
run_name=None,
),
)
# final_model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
# compile the wrapped adapter model, not the bare base model
final_model = torch.compile(final_model)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
final_model.save_pretrained(output_dir)
# Only save embeddings if the attribute exists
if hasattr(final_model, "embeddings"):
torch.save(final_model.embeddings, os.path.join(output_dir, "embeddings.pth"))
else:
print("[WARNING] final_model has no embeddings attribute; skipping save.")
try:
final_model.model.save_pretrained(os.path.join(output_dir, "peft_model"))
print(f"[INFO] PEFT model saved to {os.path.join(output_dir, 'peft_model')}")
except Exception as e:
print(f"[WARNING] Error while saving the PEFT model: {e}")
def inspect_model_structure(model):
"""检查模型结构并打印关键层信息"""
print(f"Model type: {type(model).__name__}")
print(f"Model config: {model.config.__class__.__name__}")
# Inspect embedding layers
embedding_layers = []
for name, module in model.named_modules():
if any(key in name for key in ['embed', 'wte', 'word_embeddings']):
embedding_layers.append((name, type(module).__name__))
if hasattr(module, 'weight'):
print(f"Layer {name}: shape {module.weight.shape}")
print(f"Found {len(embedding_layers)} potential embedding layers:")
for name, type_name in embedding_layers:
print(f" - {name}: {type_name}")
# Inspect attention layers
print("\nAttention structure:")
for name, module in model.named_modules():
if 'attention' in name.lower():
print(f" - {name}: {type(module).__name__}")
if __name__ == "__main__":
fire.Fire(train)

View File

@ -0,0 +1,297 @@
import torch
import torch.nn as nn
from typing import Optional, List, Union, Tuple
from transformers import LlamaForCausalLM
class KoPA(nn.Module):
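"""Wrap a causal LM and prepend trainable knowledge-graph (KG) embeddings as a soft prefix."""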
def __init__(
self,
model
) -> None:
super(KoPA, self).__init__()
self.llama_model = model
# Freeze the base model; only the KG prefix embeddings stay trainable
for param in self.llama_model.parameters():
param.requires_grad = False
hidden_size = model.config.hidden_size
self.embeddings = nn.Embedding(100, hidden_size)
self.embeddings.requires_grad_(True)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
embedding_ids: torch.LongTensor = None
):
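"""Look up KG embeddings for embedding_ids, prepend them to the token embeddings,
extend the attention mask and labels accordingly, and delegate to the wrapped causal LM."""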
if embedding_ids.max() >= self.embeddings.num_embeddings or embedding_ids.min() < 0:
print(f"[ERROR] embedding_ids out of range! max: {embedding_ids.max()}, min: {embedding_ids.min()}")
embedding_ids = torch.clamp(embedding_ids, min=0, max=self.embeddings.num_embeddings - 1)
kg_embeds = self.embeddings(embedding_ids)
batch_size, seq_len, _ = kg_embeds.shape
if hasattr(self.llama_model, 'transformer'):
# Qwen-style models
token_embeds = self.llama_model.transformer.wte(input_ids)
elif hasattr(self.llama_model, 'model') and hasattr(self.llama_model.model, 'embed_tokens'):
# original LLaMA path (use the attribute that was just checked)
token_embeds = self.llama_model.model.embed_tokens(input_ids)
else:
print("Could not locate the model's embedding layer; inspect the model structure...")
raise ValueError("Incompatible model structure")
input_embeds = torch.cat((kg_embeds, token_embeds), dim=1)
prefix_mask = torch.ones((batch_size, seq_len))
prefix_labels = torch.full((batch_size, seq_len), fill_value=-100, dtype=torch.long)
new_attention_mask = torch.cat((prefix_mask.cuda(), attention_mask), dim=-1)
new_labels = torch.cat((prefix_labels.cuda(), labels), dim=-1)
if embedding_ids.max() >= self.embeddings.num_embeddings or embedding_ids.min() < 0:
print(f"[ERROR] embedding_ids 超出范围!最大值: {embedding_ids.max()}, 最小值: {embedding_ids.min()}")
embedding_ids = torch.clamp(embedding_ids, min=0, max=self.embeddings.num_embeddings - 1)
return self.llama_model(
input_ids=None,
attention_mask=new_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=input_embeds,
labels=new_labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class KoPAWithAdapter(nn.Module):
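"""Wrap a (PEFT-)causal LM and prepend a prefix blended from a static, token-based
embedding and a dynamic prefix produced by a small MLP over 3-dim sensor data."""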
def __init__(self, model, num_prefix, tokenizer=None):
super().__init__()
self.model = model
self.num_prefix = num_prefix
hidden_size = model.config.hidden_size
# Print model info for debugging
print(f"[INFO] Initializing KoPAWithAdapter, model type: {type(model).__name__}")
# Use the tokenizer to get vocab_size
vocab_size = tokenizer.vocab_size if tokenizer else 151936 # default vocab size for Qwen2.5
print(f"[INFO] Using vocab size: {vocab_size}")
self.static_prefix_embedding = nn.Embedding(vocab_size, hidden_size)
self.embeddings = self.static_prefix_embedding # keep this alias
self.sensor_mlp = nn.Sequential(
nn.Linear(3, hidden_size // 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size // 2, hidden_size)
)
# Add LayerNorm
self.norm = nn.LayerNorm(hidden_size)
print(f"[INFO] Model initialized: hidden_size={hidden_size}, vocab_size={vocab_size}")
# Detect the embedding-layer path of the wrapped model
self._detect_embedding_path()
def _detect_embedding_path(self):
"""检测模型的嵌入层路径"""
self.embedding_path = None
# 尝试不同的常见路径
if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
self.embedding_path = "transformer.wte"
elif hasattr(self.model, 'model') and hasattr(self.model.model, 'embed_tokens'):
self.embedding_path = "model.embed_tokens"
elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(self.model.model.model, 'embed_tokens'):
self.embedding_path = "model.model.model.embed_tokens"
if self.embedding_path:
print(f"[INFO] Detected embedding path: {self.embedding_path}")
else:
print("[WARNING] Could not auto-detect the embedding path; will try several paths in forward()")
def forward(self, input_ids, attention_mask, static_prefix=None, sensor_data=None, labels=None, **kwargs):
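"""Blend the static and dynamic prefixes, prepend them to the token embeddings,
extend the attention mask and labels, and call the wrapped model with inputs_embeds."""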
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Make sure all components are on the same device
self.static_prefix_embedding = self.static_prefix_embedding.to(device)
self.sensor_mlp = self.sensor_mlp.to(device)
self.norm = self.norm.to(device)
# Static prefix
if static_prefix is not None:
static_prefix = static_prefix.to(device)
static_prefix = self.static_prefix_embedding(static_prefix)
else:
static_prefix = torch.zeros(
(batch_size, self.num_prefix, self.model.config.hidden_size),
device=device
)
# Dynamic (sensor) prefix
if sensor_data is not None:
sensor_data = sensor_data.to(device)
if sensor_data.dim() == 1:
sensor_data = sensor_data.unsqueeze(0)
try:
dynamic_prefix = self.sensor_mlp(sensor_data)
dynamic_prefix = dynamic_prefix.unsqueeze(1).expand(-1, self.num_prefix, -1)
except Exception as e:
print(f"[ERROR] sensor_mlp处理失败: {e}")
dynamic_prefix = torch.zeros_like(static_prefix)
else:
dynamic_prefix = torch.zeros_like(static_prefix)
# Blend the static and dynamic prefixes
alpha = 0.6
final_prefix = alpha * static_prefix + (1 - alpha) * dynamic_prefix
final_prefix = self.norm(final_prefix)
# Token embeddings - fetched via the detected embedding path
try:
if self.embedding_path == "transformer.wte":
token_embeds = self.model.transformer.wte(input_ids)
elif self.embedding_path == "model.embed_tokens":
token_embeds = self.model.model.embed_tokens(input_ids)
elif self.embedding_path == "model.model.model.embed_tokens":
token_embeds = self.model.model.model.embed_tokens(input_ids)
else:
# Try several possible paths
if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
token_embeds = self.model.transformer.wte(input_ids)
self.embedding_path = "transformer.wte"
elif hasattr(self.model, 'model') and hasattr(self.model.model, 'embed_tokens'):
token_embeds = self.model.model.embed_tokens(input_ids)
self.embedding_path = "model.embed_tokens"
elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(self.model.model.model, 'embed_tokens'):
token_embeds = self.model.model.model.embed_tokens(input_ids)
self.embedding_path = "model.model.model.embed_tokens"
else:
raise ValueError("无法找到嵌入层路径")
print(f"[INFO] 成功找到嵌入层路径: {self.embedding_path}")
except Exception as e:
print(f"[ERROR] 获取token嵌入失败: {e}")
# 打印模型结构以帮助调试
print("模型结构:")
for name, _ in self.model.named_modules():
if 'embed' in name or 'wte' in name:
print(f" - {name}")
raise
input_embeds = torch.cat((final_prefix, token_embeds), dim=1)
# Extend the attention mask to cover the prefix
prefix_len = final_prefix.size(1) # actual prefix length (the static prefix may span more than num_prefix tokens)
prefix_attention_mask = torch.ones(
(batch_size, prefix_len),
dtype=attention_mask.dtype,
device=device
)
extended_attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
# Labels
if labels is not None:
# Give the prefix -100 labels so it is ignored by the loss
prefix_labels = torch.full(
(batch_size, prefix_len),
fill_value=-100, # -100 means these positions are ignored by the loss
dtype=labels.dtype,
device=device
)
# Extend labels
extended_labels = torch.cat((prefix_labels, labels), dim=1)
else:
extended_labels = None
# Make sure input_ids is not passed through kwargs
if 'input_ids' in kwargs:
del kwargs['input_ids']
# Pass the extended labels
return self.model(
inputs_embeds=input_embeds,
attention_mask=extended_attention_mask,
labels=extended_labels,
use_cache=False,
**kwargs)
# class PrefixKGEmbedding(nn.Module):
# def __init__(
# self,
# num_ent,
# num_rel,
# dim_llm,
# num_prefix
# ):
# super(PrefixKGEmbedding, self).__init__()
# self.emb_dim = num_prefix * dim_llm
# self.ent_embeddings = nn.Embedding(num_ent, self.emb_dim)
# self.rel_embeddings = nn.Embedding(num_rel, self.emb_dim)
#
#
# def forward(self, triple_ids):
# head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
# h = self.ent_embeddings(head)
# r = self.rel_embeddings(relation)
# t = self.ent_embeddings(tail)
# prefix = torch.stack((h, r, t), dim=1)
# return prefix
class PretrainKGEmbedding(nn.Module):
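"""Frozen pretrained KG entity/relation embeddings projected into the LLM hidden
space by a trainable linear adapter."""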
def __init__(
self,
pretrain_ent_embs,
pretrain_rel_embs,
dim_llm,
num_prefix
):
super(PretrainKGEmbedding, self).__init__()
self.num_prefix = num_prefix
self.llm_dim = dim_llm
self.emb_dim = num_prefix * dim_llm
self.ent_embeddings = nn.Embedding.from_pretrained(pretrain_ent_embs)
self.rel_embeddings = nn.Embedding.from_pretrained(pretrain_rel_embs)
self.pretrain_dim = self.ent_embeddings.weight.shape[1]
# Freeze the pretrained embeddings
self.ent_embeddings.requires_grad_(False)
self.rel_embeddings.requires_grad_(False)
self.adapter = nn.Linear(self.pretrain_dim, self.emb_dim)
def forward(self, triple_ids):
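"""Map (head, relation, tail) triples, or bare entity ids during entity-aware
pre-tuning, to prefix embeddings in the LLM hidden space via the linear adapter."""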
# main training stage
if triple_ids.shape[1] == 3:
head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
h = self.ent_embeddings(head)
r = self.rel_embeddings(relation)
t = self.ent_embeddings(tail)
pretrain_embs = torch.stack((h, r, t), dim=1)
prefix = self.adapter(pretrain_embs).reshape(-1, 3*self.num_prefix, self.llm_dim)
return prefix
# entity-aware pre-tuning
else:
ent = triple_ids.reshape(-1,)
emb = self.ent_embeddings(ent)
prefix = self.adapter(emb).reshape(-1, self.num_prefix, self.llm_dim)
# print(prefix.shape)
return prefix

Binary file not shown.

View File

@ -16,7 +16,7 @@ import bitsandbytes as bnb
"""
from peft import PrefixTuningConfig, get_peft_model
from transformers import LlamaForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.prompter import Prompter
@ -26,6 +26,7 @@ def custom_collate_fn(batch):
attention_mask_list = []
static_prefix_list = []
sensor_data_list = []
# qwen_dict= {'llama_eos_tid':, 'qwen_eos_tid':}
for b in batch:
# Make sure inputs are tensors
@ -54,11 +55,25 @@ def custom_collate_fn(batch):
else:
sensor_data = b["sensor_data"]
sensor_data_list.append(sensor_data)
max_length=0
for one_inputs in input_ids_list:
max_length = one_inputs.size(0) if max_length < one_inputs.size(0) else max_length
input_ids_list_=list()
for one_inputs in input_ids_list:
input_ids_list_.append(torch.cat((one_inputs, torch.full((max_length-one_inputs.size(0),), 0, dtype=torch.int)), dim=-1))
attention_mask_list_=list()
for mask in attention_mask_list:
attention_mask_list_.append(torch.cat((mask, torch.full((max_length-mask.size(0),), 0, dtype=torch.int)), dim=-1))
# print("=====",input_ids_list)
# exit(0)
# Stack into batch tensors
result = {
"input_ids": torch.stack(input_ids_list),
"attention_mask": torch.stack(attention_mask_list),
"input_ids": torch.stack(input_ids_list_),
"attention_mask": torch.stack(attention_mask_list_),
}
if static_prefix_list:
@ -75,22 +90,26 @@ def custom_collate_fn(batch):
else:
labels = b["labels"]
labels_list.append(labels)
labels_list_ = list()
for label in labels_list:
# pad labels with -100 so the padded positions are ignored by the loss
labels_list_.append(torch.cat((label, torch.full((max_length-label.size(0),), -100, dtype=torch.long)), dim=-1))
result["labels"] = torch.stack(labels_list)
result["labels"] = torch.stack(labels_list_)
return result
def train(
# model/data params
base_model="models/Llama-3.2-3B-Instruct",
data_path: str = "data/CoDeX-S-train.json",
base_model="/root/shared-nvme/models/Qwen2.5-7B-Instruct",
data_path: str = "/root/shared-nvme/dataset/olive_dataset.json",
output_dir: str = "output",
# training hyperparams
batch_size: int = 16,
micro_batch_size: int = 16,
num_epochs: int = 2,
learning_rate: float = 3e-4,
learning_rate: float = 1e-4,
cutoff_len: int = 512,
val_set_size: int = 0,
num_prefix: int = 1,
@ -142,25 +161,30 @@ def train(
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
gradient_accumulation_steps = gradient_accumulation_steps // world_size
model = LlamaForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
# load_in_8bit=True,
load_in_8bit=True,
# the Auto class picks the correct model type automatically
torch_dtype=torch.float16,
device_map=device_map,
trust_remote_code=True, # required for Qwen models
)
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer = AutoTokenizer.from_pretrained(
base_model,
trust_remote_code=True, # required for Qwen tokenizers as well
padding_side="left", # left padding is also recommended for Qwen
)
tokenizer.pad_token = tokenizer.eos_token
# tokenizer.pad_token_id = (
# 0 # unk. we want this to be different from the eos token
# )
tokenizer.padding_side = "left" # Allow batched inference
tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = "left" # Allow batched inference
# model.gradient_checkpointing_enable()
# tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
model.generation_config.pad_token_id = model.generation_config.eos_token_id
@ -212,10 +236,14 @@ def train(
tokenized_full_prompt = tokenizer(
full_prompt,
truncation=True,
max_length=128,
padding="max_length",
return_tensors="pt",
max_length=cutoff_len,
padding=True,
return_tensors='pt',
)
# for k,v in tokenized_full_prompt.items(): print("======k,v",k,v,type(k),type(v))
# exit(0)
tokenized_full_prompt = {k: v.squeeze(0) for k, v in tokenized_full_prompt.items()}
@ -233,7 +261,7 @@ def train(
tokenized_full_prompt["static_prefix"] = static_prefix
# print(f"[DEBUG] static_prefix (after clamp): {static_prefix}")
print(f"[DEBUG] tokenizer vocab_size: {tokenizer.vocab_size}")
# print(f"[DEBUG] tokenizer vocab_size: {tokenizer.vocab_size}")
# Dynamic (sensor) data
sensor_values = torch.zeros(3, dtype=torch.float) # default is a Tensor, not a list
@ -263,7 +291,7 @@ def train(
# Clamp the range to guard against outliers
sensor_values = torch.clamp(sensor_values, min=-100, max=100)
print(f"[DEBUG] sensor_values (AFTER FIX): {sensor_values}") # 🔥 打印调试信息
# print(f"[DEBUG] sensor_values (AFTER FIX): {sensor_values}") # 🔥 打印调试信息
if not isinstance(sensor_values, torch.Tensor):
sensor_values = torch.tensor(sensor_values, dtype=torch.float)
@ -271,6 +299,13 @@ def train(
# Final type checks and conversions
for key in tokenized_full_prompt:
if isinstance(tokenized_full_prompt[key], list):
# Convert lists to tensors
tokenized_full_prompt[key] = torch.tensor(tokenized_full_prompt[key])
elif isinstance(tokenized_full_prompt[key], torch.Tensor) and tokenized_full_prompt[key].dim() > 1:
# Squeeze extra dimensions if needed
tokenized_full_prompt[key] = tokenized_full_prompt[key].squeeze(0)
if key in ["input_ids", "attention_mask"] and isinstance(tokenized_full_prompt[key], list):
tokenized_full_prompt[key] = torch.tensor(tokenized_full_prompt[key], dtype=torch.long)
@ -296,6 +331,7 @@ def train(
return tokenized_full_prompt
# Create the PrefixTuning config
prefix_config = PrefixTuningConfig(
@ -412,6 +448,29 @@ def train(
except Exception as e:
print(f"[WARNING] 保存PEFT模型时出错: {e}")
def inspect_model_structure(model):
"""检查模型结构并打印关键层信息"""
print(f"Model type: {type(model).__name__}")
print(f"Model config: {model.config.__class__.__name__}")
# Inspect embedding layers
embedding_layers = []
for name, module in model.named_modules():
if any(key in name for key in ['embed', 'wte', 'word_embeddings']):
embedding_layers.append((name, type(module).__name__))
if hasattr(module, 'weight'):
print(f"Layer {name}: shape {module.weight.shape}")
print(f"Found {len(embedding_layers)} potential embedding layers:")
for name, type_name in embedding_layers:
print(f" - {name}: {type_name}")
# Inspect attention layers
print("\nAttention structure:")
for name, module in model.named_modules():
if 'attention' in name.lower():
print(f" - {name}: {type(module).__name__}")
if __name__ == "__main__":
fire.Fire(train)

101
kopa.py
View File

@ -8,18 +8,24 @@ from transformers import LlamaForCausalLM
class KoPA(nn.Module):
def __init__(
self,
model: LlamaForCausalLM
model
) -> None:
super(KoPA, self).__init__()
self.llama_model = model
self.embeddings = nn.Embedding(100, 3072)
# self.embeddings = PrefixKGEmbedding(
# num_ent=2034,
# num_rel=42,
# dim_llm=3072,
# num_prefix=1
# )
# Freeze the base model; only the KG prefix embeddings stay trainable
for param in self.llama_model.parameters():
param.requires_grad = False
hidden_size = model.config.hidden_size
self.embeddings = nn.Embedding(100, hidden_size)
self.embeddings.requires_grad_(True)
def forward(
self,
input_ids: torch.LongTensor = None,
@ -39,7 +45,16 @@ class KoPA(nn.Module):
embedding_ids = torch.clamp(embedding_ids, min=0, max=self.embeddings.num_embeddings - 1)
kg_embeds = self.embeddings(embedding_ids)
batch_size, seq_len, _ = kg_embeds.shape
if hasattr(self.llama_model, 'transformer'):
# Qwen-style models
token_embeds = self.llama_model.transformer.wte(input_ids)
elif hasattr(self.llama_model, 'model') and hasattr(self.llama_model.model, 'embed_tokens'):
# original LLaMA path (use the attribute that was just checked)
token_embeds = self.llama_model.model.embed_tokens(input_ids)
else:
print("Could not locate the model's embedding layer; inspect the model structure...")
raise ValueError("Incompatible model structure")
input_embeds = torch.cat((kg_embeds, token_embeds), dim=1)
prefix_mask = torch.ones((batch_size, seq_len))
prefix_labels = torch.full((batch_size, seq_len), fill_value=-100, dtype=torch.long)
@ -70,8 +85,12 @@ class KoPAWithAdapter(nn.Module):
self.num_prefix = num_prefix
hidden_size = model.config.hidden_size
# Print model info for debugging
print(f"[INFO] Initializing KoPAWithAdapter, model type: {type(model).__name__}")
# Use the tokenizer to get vocab_size
vocab_size = tokenizer.vocab_size if tokenizer else 32000
vocab_size = tokenizer.vocab_size if tokenizer else 151936 # default vocab size for Qwen2.5
print(f"[INFO] Using vocab size: {vocab_size}")
self.static_prefix_embedding = nn.Embedding(vocab_size, hidden_size)
self.embeddings = self.static_prefix_embedding # keep this alias
@ -88,6 +107,26 @@ class KoPAWithAdapter(nn.Module):
print(f"[INFO] 模型初始化: hidden_size={hidden_size}, vocab_size={vocab_size}")
# 检测模型嵌入层路径
self._detect_embedding_path()
def _detect_embedding_path(self):
"""检测模型的嵌入层路径"""
self.embedding_path = None
# 尝试不同的常见路径
if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
self.embedding_path = "transformer.wte"
elif hasattr(self.model, 'model') and hasattr(self.model.model, 'embed_tokens'):
self.embedding_path = "model.embed_tokens"
elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(self.model.model.model, 'embed_tokens'):
self.embedding_path = "model.model.model.embed_tokens"
if self.embedding_path:
print(f"[INFO] Detected embedding path: {self.embedding_path}")
else:
print("[WARNING] Could not auto-detect the embedding path; will try several paths in forward()")
def forward(self, input_ids, attention_mask, static_prefix=None, sensor_data=None, labels=None, **kwargs):
batch_size, seq_len = input_ids.shape
device = input_ids.device
@ -128,8 +167,37 @@ class KoPAWithAdapter(nn.Module):
final_prefix = alpha * static_prefix + (1 - alpha) * dynamic_prefix
final_prefix = self.norm(final_prefix)
# Token embeddings
# Token embeddings - fetched via the detected embedding path
try:
if self.embedding_path == "transformer.wte":
token_embeds = self.model.transformer.wte(input_ids)
elif self.embedding_path == "model.embed_tokens":
token_embeds = self.model.model.embed_tokens(input_ids)
elif self.embedding_path == "model.model.model.embed_tokens":
token_embeds = self.model.model.model.embed_tokens(input_ids)
else:
# Try several possible paths
if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
token_embeds = self.model.transformer.wte(input_ids)
self.embedding_path = "transformer.wte"
elif hasattr(self.model, 'model') and hasattr(self.model.model, 'embed_tokens'):
token_embeds = self.model.model.embed_tokens(input_ids)
self.embedding_path = "model.embed_tokens"
elif hasattr(self.model, 'model') and hasattr(self.model.model, 'model') and hasattr(self.model.model.model, 'embed_tokens'):
token_embeds = self.model.model.model.embed_tokens(input_ids)
self.embedding_path = "model.model.model.embed_tokens"
else:
raise ValueError("无法找到嵌入层路径")
print(f"[INFO] 成功找到嵌入层路径: {self.embedding_path}")
except Exception as e:
print(f"[ERROR] 获取token嵌入失败: {e}")
# 打印模型结构以帮助调试
print("模型结构:")
for name, _ in self.model.named_modules():
if 'embed' in name or 'wte' in name:
print(f" - {name}")
raise
input_embeds = torch.cat((final_prefix, token_embeds), dim=1)
# Extend the attention mask to cover the prefix
@ -140,7 +208,7 @@ class KoPAWithAdapter(nn.Module):
)
extended_attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
# Key fix: handle labels
# Labels
if labels is not None:
# Give the prefix -100 labels so it is ignored by the loss
prefix_labels = torch.full(
@ -154,22 +222,15 @@ class KoPAWithAdapter(nn.Module):
else:
extended_labels = None
# Debug output
# print(f"[DEBUG] original input size: {input_ids.shape}")
# print(f"[DEBUG] extended embeds size: {input_embeds.shape}")
# print(f"[DEBUG] extended mask size: {extended_attention_mask.shape}")
# if extended_labels is not None:
# print(f"[DEBUG] extended labels size: {extended_labels.shape}")
# Make sure input_ids is not passed through kwargs
if 'input_ids' in kwargs:
del kwargs['input_ids']
# Pass the extended labels
# Pass the extended labels
return self.model(
inputs_embeds=input_embeds,
attention_mask=extended_attention_mask,
labels=extended_labels, # this is the key change
labels=extended_labels,
use_cache=False,
**kwargs)

Binary file not shown.

Binary file not shown.