DSP-LLAMA preliminary fine-tuning
This commit is contained in:
parent 45b7a67876
commit 1b3dd9475c
122309 data/olive_dataset.json (Normal file)
File diff suppressed because it is too large
368 finetune_kopa.py
@@ -6,7 +6,8 @@ import fire
import torch
import transformers
from datasets import load_dataset
from kopa import KoPA, KoPAWithAdapter

from kopa import KoPAWithAdapter

"""
Unused imports:
@@ -14,55 +15,100 @@ import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from peft import PrefixTuningConfig, get_peft_model
from transformers import LlamaForCausalLM, AutoTokenizer

from utils.prompter import Prompter


def custom_collate_fn(batch):
    input_ids_list = []
    attention_mask_list = []
    static_prefix_list = []
    sensor_data_list = []

    for b in batch:
        # Make sure the inputs are tensors
        if isinstance(b["input_ids"], list):
            input_ids = torch.tensor(b["input_ids"], dtype=torch.long)
        else:
            input_ids = b["input_ids"]
        input_ids_list.append(input_ids)

        if isinstance(b["attention_mask"], list):
            attention_mask = torch.tensor(b["attention_mask"], dtype=torch.long)
        else:
            attention_mask = b["attention_mask"]
        attention_mask_list.append(attention_mask)

        if "static_prefix" in b:
            if isinstance(b["static_prefix"], list):
                static_prefix = torch.tensor(b["static_prefix"], dtype=torch.long)
            else:
                static_prefix = b["static_prefix"]
            static_prefix_list.append(static_prefix)

        if "sensor_data" in b:
            if isinstance(b["sensor_data"], list):
                sensor_data = torch.tensor(b["sensor_data"], dtype=torch.float)
            else:
                sensor_data = b["sensor_data"]
            sensor_data_list.append(sensor_data)

    # Stack the collected tensors
    result = {
        "input_ids": torch.stack(input_ids_list),
        "attention_mask": torch.stack(attention_mask_list),
    }

    if static_prefix_list:
        result["static_prefix"] = torch.stack(static_prefix_list)

    if sensor_data_list:
        result["sensor_data"] = torch.stack(sensor_data_list)

    if "labels" in batch[0]:
        labels_list = []
        for b in batch:
            if isinstance(b["labels"], list):
                labels = torch.tensor(b["labels"], dtype=torch.long)
            else:
                labels = b["labels"]
            labels_list.append(labels)

        result["labels"] = torch.stack(labels_list)

    return result
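A minimal sketch of exercising custom_collate_fn on a toy batch (field names mirror the keys handled above; the concrete values and lengths are made-up assumptions):

import torch

toy_batch = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1],
     "static_prefix": [5, 6], "sensor_data": [20.0, 0.4, 1.1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5, 6], "attention_mask": [1, 1, 1],
     "static_prefix": [7, 8], "sensor_data": [21.5, 0.5, 0.9], "labels": [4, 5, 6]},
]
collated = custom_collate_fn(toy_batch)  # assumes the function defined above is in scope
assert collated["input_ids"].shape == (2, 3)       # list fields are stacked into batched tensors
assert collated["sensor_data"].dtype == torch.float32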
def train(
    # model/data params
    base_model = "models/Llama-3.2-3B-Instruct",
    data_path: str = "data/CoDeX-S-train.json",
    output_dir: str = "output",
    # training hyperparams
    batch_size: int = 16,
    micro_batch_size: int = 16,
    num_epochs: int = 2,
    learning_rate: float = 3e-4,
    cutoff_len: int = 512,
    val_set_size: int = 0,
    # lora hyperparams
    lora_r: int = 16,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
    ],
    num_prefix: int = 1,
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
    kge_model: str = "data/CoDeX-S.pth"
    # model/data params
    base_model="models/Llama-3.2-3B-Instruct",
    data_path: str = "data/CoDeX-S-train.json",
    output_dir: str = "output",
    # training hyperparams
    batch_size: int = 16,
    micro_batch_size: int = 16,
    num_epochs: int = 2,
    learning_rate: float = 3e-4,
    cutoff_len: int = 512,
    val_set_size: int = 0,
    num_prefix: int = 1,
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
):
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"Training Alpaca model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
@@ -72,11 +118,6 @@ def train(
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"lora_r: {lora_r}\n"
            f"num_prefix: {num_prefix}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
@@ -86,7 +127,6 @@ def train(
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
            f"kge model: {kge_model}\n"
        )
    assert (
        base_model
@@ -102,7 +142,6 @@ def train(
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    gradient_accumulation_steps = gradient_accumulation_steps // world_size

    model = LlamaForCausalLM.from_pretrained(
        base_model,
        # load_in_8bit=True,
@@ -110,34 +149,57 @@ def train(
        device_map=device_map,
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )

    # tokenizer.pad_token_id = (
    #     0  # unk. we want this to be different from the eos token
    # )
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

        result["labels"] = result["input_ids"].copy()
    def ensure_consistent_keys(dataset):
        all_keys = set()
        for example in dataset:
            all_keys.update(example.keys())

        return result
        for example in dataset:
            for key in all_keys:
                if key not in example:
                    if key == "static_prefix":
                        example[key] = ""
                    elif key == "sensor_data":
                        example[key] = [0, 0, 0]

        return dataset

    # def tokenize(prompt, add_eos_token=True):
    #     # there's probably a way to do this with the tokenizer settings
    #     # but again, gotta move fast
    #     result = tokenizer(
    #         prompt,
    #         truncation=True,
    #         max_length=cutoff_len,
    #         padding=False,
    #         return_tensors=None,
    #     )
    #     if (
    #         result["input_ids"][-1] != tokenizer.eos_token_id
    #         and len(result["input_ids"]) < cutoff_len
    #         and add_eos_token
    #     ):
    #         result["input_ids"].append(tokenizer.eos_token_id)
    #         result["attention_mask"].append(1)
    #
    #     result["labels"] = result["input_ids"].copy()
    #
    #     return result
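For reference, a small illustration of what ensure_consistent_keys does to a key-inconsistent dataset; a plain list of dicts stands in for the Hugging Face dataset here, and the records are invented:

toy_rows = [
    {"instruction": "a", "static_prefix": "x"},
    {"instruction": "b", "sensor_data": [1, 2, 3]},
]
fixed_rows = ensure_consistent_keys(toy_rows)
# Every row now carries both optional keys, filled with the defaults used above.
assert fixed_rows[0]["sensor_data"] == [0, 0, 0]
assert fixed_rows[1]["static_prefix"] == ""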
    def generate_and_tokenize_prompt(data_point):
        full_prompt = prompter.generate_prompt(
@@ -145,38 +207,114 @@ def train(
            data_point["input"],
            data_point["output"],
        )
        tokenized_full_prompt = tokenize(full_prompt)

        # Tokenize the full prompt text
        tokenized_full_prompt = tokenizer(
            full_prompt,
            truncation=True,
            max_length=128,
            padding="max_length",
            return_tensors="pt",
        )

        tokenized_full_prompt = {k: v.squeeze(0) for k, v in tokenized_full_prompt.items()}

        # Process the static prefix
        static_prefix = tokenizer(
            data_point["instruction"],
            truncation=True,
            max_length=10,
            padding="max_length",
            return_tensors="pt"
        )["input_ids"].squeeze(0)

        # Clamp indices so `static_prefix` never exceeds `vocab_size`
        static_prefix = torch.clamp(static_prefix, min=0, max=tokenizer.vocab_size - 1)

        tokenized_full_prompt["static_prefix"] = static_prefix
        # print(f"[DEBUG] static_prefix (after clamp): {static_prefix}")
        print(f"[DEBUG] tokenizer vocab_size: {tokenizer.vocab_size}")

        # Process the dynamic sensor data
        sensor_values = torch.zeros(3, dtype=torch.float)  # default is a Tensor, not a list

        if data_point["type"] == "dynamic" and "sensor_data" in data_point:
            raw_sensor_values = data_point["sensor_data"]

            try:
                sensor_values = torch.tensor([
                    float(raw_sensor_values.get("temperature", 0.0)),
                    float(raw_sensor_values.get("humidity", 0.0)),
                    float(raw_sensor_values.get("conductivity", 0.0))
                ], dtype=torch.float)
            except Exception as e:
                # print(f"[ERROR] failed to parse sensor_data: {raw_sensor_values}, {e}")
                if torch.isnan(sensor_values).any() or torch.isinf(sensor_values).any():
                    # print(f"[ERROR] NaN/Inf detected in sensor_values: {sensor_values}")
                    sensor_values = torch.zeros(3, dtype=torch.float)

        # Make sure sensor_values is a Tensor
        if torch.isnan(sensor_values).any() or torch.isinf(sensor_values).any():
            print(f"[ERROR] NaN/Inf detected in sensor_values")
        if torch.isnan(sensor_values).any() or torch.isinf(sensor_values).any():
            print(f"[ERROR] NaN/Inf detected in sensor_values")
            sensor_values = torch.zeros(3, dtype=torch.float)

        # Clamp the range to guard against outliers
        sensor_values = torch.clamp(sensor_values, min=-100, max=100)

        print(f"[DEBUG] sensor_values (AFTER FIX): {sensor_values}")  # debug print
        if not isinstance(sensor_values, torch.Tensor):
            sensor_values = torch.tensor(sensor_values, dtype=torch.float)

        tokenized_full_prompt["sensor_data"] = sensor_values  # always keep this a Tensor

        # Final type checks and conversions
        for key in tokenized_full_prompt:
            if key in ["input_ids", "attention_mask"] and isinstance(tokenized_full_prompt[key], list):
                tokenized_full_prompt[key] = torch.tensor(tokenized_full_prompt[key], dtype=torch.long)

        if isinstance(tokenized_full_prompt["static_prefix"], list):
            tokenized_full_prompt["static_prefix"] = torch.tensor(tokenized_full_prompt["static_prefix"],
                                                                  dtype=torch.long)

        # Make sure sensor_data is a tensor
        if not isinstance(tokenized_full_prompt["sensor_data"], torch.Tensor):
            tokenized_full_prompt["sensor_data"] = torch.tensor(tokenized_full_prompt["sensor_data"], dtype=torch.float)

        tokenized_full_prompt["labels"] = tokenized_full_prompt["input_ids"].clone()

        # If loss on the input portion is unwanted, set its labels to -100
        if not train_on_inputs:
            user_prompt = prompter.generate_prompt(
                data_point["instruction"], data_point["input"]
            )
            tokenized_user_prompt = tokenize(
                user_prompt, add_eos_token=add_eos_token
            )
            user_prompt_len = len(tokenized_user_prompt["input_ids"])
            # Find the boundary between the user input and the assistant output
            sep = tokenizer.encode(prompter.separator)
            instruction_tokens = tokenizer.encode(data_point["instruction"])

            if add_eos_token:
                user_prompt_len -= 1
            # Set the labels for the user-input portion to -100
            sep_pos = tokenized_full_prompt["input_ids"].tolist().index(sep[0])
            tokenized_full_prompt["labels"][:sep_pos] = -100

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt
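The record layout that generate_and_tokenize_prompt expects can be inferred from the field accesses above; a hypothetical example row (keys taken from this diff, values invented for illustration) would look like:

example_row = {
    "instruction": "Assess the irrigation status of the olive grove.",
    "input": "",
    "output": "Soil moisture is adequate; no irrigation is needed.",
    "type": "dynamic",  # "dynamic" rows carry sensor readings
    "sensor_data": {"temperature": 24.5, "humidity": 0.41, "conductivity": 1.2},
}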
    # model = prepare_model_for_int8_training(model)
    # Create the PrefixTuning config

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    prefix_config = PrefixTuningConfig(
        num_virtual_tokens=num_prefix,
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, config)
    slama_model = KoPAWithAdapter(model, num_prefix, kge_model=kge_model)

    # Create the PEFT model
    peft_model = get_peft_model(model, prefix_config)

    # Create the final KoPAWithAdapter model
    final_model = KoPAWithAdapter(peft_model, num_prefix, tokenizer)
    device = next(model.parameters()).device
    print(f"[INFO] Using device: {device}")

    # Make sure final_model and all of its components are on the same device
    final_model = final_model.to(device)
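An optional sanity check after the wrapping above (a sketch only; print_trainable_parameters is the standard method on a PEFT-wrapped model, and the variable names follow the code above):

    peft_model.print_trainable_parameters()  # the prefix virtual tokens should dominate the trainable count
    total = sum(p.numel() for p in final_model.parameters())
    trainable = sum(p.numel() for p in final_model.parameters() if p.requires_grad)
    print(f"[INFO] trainable params: {trainable} / {total}")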
    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
@@ -199,7 +337,6 @@ def train(
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

@@ -211,12 +348,15 @@ def train(
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        train_data = ensure_consistent_keys(train_data)
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        train_data = ensure_consistent_keys(train_data)
        val_data = None

    if not ddp and torch.cuda.device_count() > 1:
@@ -225,7 +365,8 @@ def train(
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=slama_model,
        model=final_model,
        data_collator=custom_collate_fn,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
@@ -249,30 +390,27 @@ def train(
            report_to=None,
            run_name=None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(
            self, old_state_dict()
        )
    ).__get__(model, type(model))
    # final_model.config.use_cache = False

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)
        final_model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)
    torch.save(slama_model.embeddings, os.path.join(output_dir, "embeddings.pth"))
    final_model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )
    # Make sure the embeddings exist before saving them
    if hasattr(final_model, "embeddings"):
        torch.save(final_model.embeddings, os.path.join(output_dir, "embeddings.pth"))
    else:
        print("[WARNING] final_model has no embeddings attribute; skipping save.")

    try:
        final_model.model.save_pretrained(os.path.join(output_dir, "peft_model"))
        print(f"[INFO] PEFT model saved to {os.path.join(output_dir, 'peft_model')}")
    except Exception as e:
        print(f"[WARNING] Error while saving the PEFT model: {e}")


if __name__ == "__main__":
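A hedged sketch of a launch command: the flag names simply mirror the train() parameters above, the data path is illustrative, and the fire.Fire(train) entry point is assumed from the import of fire (the actual dispatch line is not shown in this hunk):

python finetune_kopa.py \
    --base_model models/Llama-3.2-3B-Instruct \
    --data_path data/olive_dataset.json \
    --output_dir output \
    --num_epochs 2 \
    --micro_batch_size 16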
@@ -5,7 +5,7 @@ import transformers
from peft import PeftModel
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from transformers import GenerationConfig, LlamaForCausalLM, AutoTokenizer

base_path = 'YOUR LLM PATH'

@@ -33,7 +33,7 @@ if __name__ == "__main__":
    embedding_path = "{}/embeddings.pth".format(lora_weights)
    test_dataset = load_test_dataset(test_data_path)
    kg_embeddings = torch.load(embedding_path).to(cuda)
    tokenizer = LlamaTokenizer.from_pretrained(base_path)
    tokenizer = AutoTokenizer.from_pretrained(base_path, use_fast=False)
    model = LlamaForCausalLM.from_pretrained(
        base_path,
        torch_dtype=torch.float16
227 kopa.py
@@ -3,7 +3,6 @@ import torch.nn as nn
from typing import Optional, List, Union, Tuple

from transformers import LlamaForCausalLM
from process_kge import load_pretrain_kge


class KoPA(nn.Module):
@@ -13,14 +12,14 @@ class KoPA(nn.Module):
    ) -> None:
        super(KoPA, self).__init__()
        self.llama_model = model
        # self.embeddings = nn.Embedding(100, 4096)
        self.embeddings = PrefixKGEmbedding(
            num_ent=2034,
            num_rel=42,
            dim_llm=3072,
            num_prefix=1
        )

        self.embeddings = nn.Embedding(100, 3072)
        # self.embeddings = PrefixKGEmbedding(
        #     num_ent=2034,
        #     num_rel=42,
        #     dim_llm=3072,
        #     num_prefix=1
        # )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
@@ -35,6 +34,9 @@ class KoPA(nn.Module):
        return_dict: Optional[bool] = None,
        embedding_ids: torch.LongTensor = None
    ):
        if embedding_ids.max() >= self.embeddings.num_embeddings or embedding_ids.min() < 0:
            print(f"[ERROR] embedding_ids out of range! max: {embedding_ids.max()}, min: {embedding_ids.min()}")
            embedding_ids = torch.clamp(embedding_ids, min=0, max=self.embeddings.num_embeddings - 1)
        kg_embeds = self.embeddings(embedding_ids)
        batch_size, seq_len, _ = kg_embeds.shape
        token_embeds = self.llama_model.model.model.embed_tokens(input_ids)
@@ -43,6 +45,10 @@ class KoPA(nn.Module):
        prefix_labels = torch.full((batch_size, seq_len), fill_value=-100, dtype=torch.long)
        new_attention_mask = torch.cat((prefix_mask.cuda(), attention_mask), dim=-1)
        new_labels = torch.cat((prefix_labels.cuda(), labels), dim=-1)
        if embedding_ids.max() >= self.embeddings.num_embeddings or embedding_ids.min() < 0:
            print(f"[ERROR] embedding_ids out of range! max: {embedding_ids.max()}, min: {embedding_ids.min()}")
            embedding_ids = torch.clamp(embedding_ids, min=0, max=self.embeddings.num_embeddings - 1)

        return self.llama_model(
            input_ids=None,
            attention_mask=new_attention_mask,
@@ -58,87 +64,136 @@ class KoPA(nn.Module):
class KoPAWithAdapter(nn.Module):
    def __init__(
        self,
        model: LlamaForCausalLM,
        num_prefix: int,
        kge_model: str = "data/UMLS-rotate.pth",
        pretrain_emb_path = None
    ) -> None:
        super(KoPAWithAdapter, self).__init__()
        self.llama_model = model
        ent_embs, rel_embs = load_pretrain_kge(kge_model)
        if pretrain_emb_path is None:
            print("Adapter Trained From Scratch".format(pretrain_emb_path))
            self.embeddings = PretrainKGEmbedding(
                pretrain_ent_embs=ent_embs,
                pretrain_rel_embs=rel_embs,
                dim_llm=3072,
                num_prefix=num_prefix
            )
        else:
            print("Adapter Load From {}".format(pretrain_emb_path))
            self.embeddings = torch.load(pretrain_emb_path)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        embedding_ids: torch.LongTensor = None
    ):
        kg_embeds = self.embeddings(embedding_ids)
        # print(kg_embeds.shape)
        batch_size, seq_len, _ = kg_embeds.shape
        token_embeds = self.llama_model.model.model.embed_tokens(input_ids)
        input_embeds = torch.cat((kg_embeds, token_embeds), dim=1)
        prefix_mask = torch.ones((batch_size, seq_len))
        prefix_labels = torch.full((batch_size, seq_len), fill_value=-100, dtype=torch.long)
        new_attention_mask = torch.cat((prefix_mask.cuda(), attention_mask), dim=-1)
        new_labels = torch.cat((prefix_labels.cuda(), labels), dim=-1)
        return self.llama_model(
            input_ids=None,
            attention_mask=new_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=input_embeds,
            labels=new_labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
    def __init__(self, model, num_prefix, tokenizer=None):
        super().__init__()
        self.model = model
        self.num_prefix = num_prefix
        hidden_size = model.config.hidden_size

        # Use the tokenizer to obtain vocab_size
        vocab_size = tokenizer.vocab_size if tokenizer else 32000

        self.static_prefix_embedding = nn.Embedding(vocab_size, hidden_size)
        self.embeddings = self.static_prefix_embedding  # keep this attribute around

        self.sensor_mlp = nn.Sequential(
            nn.Linear(3, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, hidden_size)
        )

        # Add LayerNorm
        self.norm = nn.LayerNorm(hidden_size)

        print(f"[INFO] Model initialized: hidden_size={hidden_size}, vocab_size={vocab_size}")

class PrefixKGEmbedding(nn.Module):
    def __init__(
        self,
        num_ent,
        num_rel,
        dim_llm,
        num_prefix
    ):
        super(PrefixKGEmbedding, self).__init__()
        self.emb_dim = num_prefix * dim_llm
        self.ent_embeddings = nn.Embedding(num_ent, self.emb_dim)
        self.rel_embeddings = nn.Embedding(num_rel, self.emb_dim)

    def forward(self, input_ids, attention_mask, static_prefix=None, sensor_data=None, labels=None, **kwargs):
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

    def forward(self, triple_ids):
        head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
        h = self.ent_embeddings(head)
        r = self.rel_embeddings(relation)
        t = self.ent_embeddings(tail)
        prefix = torch.stack((h, r, t), dim=1)
        return prefix
        # Make sure all components are on the same device
        self.static_prefix_embedding = self.static_prefix_embedding.to(device)
        self.sensor_mlp = self.sensor_mlp.to(device)
        self.norm = self.norm.to(device)

        # Process the static prefix
        if static_prefix is not None:
            static_prefix = static_prefix.to(device)
            static_prefix = self.static_prefix_embedding(static_prefix)
        else:
            static_prefix = torch.zeros(
                (batch_size, self.num_prefix, self.model.config.hidden_size),
                device=device
            )

        # Process the dynamic prefix
        if sensor_data is not None:
            sensor_data = sensor_data.to(device)

            if sensor_data.dim() == 1:
                sensor_data = sensor_data.unsqueeze(0)

            try:
                dynamic_prefix = self.sensor_mlp(sensor_data)
                dynamic_prefix = dynamic_prefix.unsqueeze(1).expand(-1, self.num_prefix, -1)
            except Exception as e:
                print(f"[ERROR] sensor_mlp failed: {e}")
                dynamic_prefix = torch.zeros_like(static_prefix)
        else:
            dynamic_prefix = torch.zeros_like(static_prefix)

        # Mix the two prefixes
        alpha = 0.6
        final_prefix = alpha * static_prefix + (1 - alpha) * dynamic_prefix
        final_prefix = self.norm(final_prefix)

        # Token embeddings
        token_embeds = self.model.model.embed_tokens(input_ids)
        input_embeds = torch.cat((final_prefix, token_embeds), dim=1)

        # Extend the attention mask
        prefix_attention_mask = torch.ones(
            (batch_size, self.num_prefix),
            dtype=attention_mask.dtype,
            device=device
        )
        extended_attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        # Key fix: handle the labels
        if labels is not None:
            # Create -100 labels for the prefix positions (ignored by the loss)
            prefix_labels = torch.full(
                (batch_size, self.num_prefix),
                fill_value=-100,  # -100 means the loss ignores these positions
                dtype=labels.dtype,
                device=device
            )
            # Extend the labels
            extended_labels = torch.cat((prefix_labels, labels), dim=1)
        else:
            extended_labels = None

        # Debug output
        # print(f"[DEBUG] original input size: {input_ids.shape}")
        # print(f"[DEBUG] extended embeddings size: {input_embeds.shape}")
        # print(f"[DEBUG] extended mask size: {extended_attention_mask.shape}")
        # if extended_labels is not None:
        #     print(f"[DEBUG] extended labels size: {extended_labels.shape}")

        # Make sure input_ids is not passed through as well
        if 'input_ids' in kwargs:
            del kwargs['input_ids']

        # Pass the extended labels through
        return self.model(
            inputs_embeds=input_embeds,
            attention_mask=extended_attention_mask,
            labels=extended_labels,  # this is the key change
            use_cache=False,
            **kwargs)
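A minimal sketch of invoking the new KoPAWithAdapter.forward with dummy tensors (shapes and values are illustrative assumptions, and it is not intended to be run against the full 3B checkpoint as-is). It assumes final_model was built with num_prefix equal to the static-prefix length used here, since the attention-mask and label extension in forward() use num_prefix:

batch_size, seq_len, prefix_len = 2, 16, 1
dummy_ids = torch.randint(0, 1000, (batch_size, seq_len))
dummy_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
dummy_prefix = torch.randint(0, 1000, (batch_size, prefix_len))  # clamped token ids, as in the data pipeline
dummy_sensor = torch.rand(batch_size, 3)                         # temperature / humidity / conductivity
outputs = final_model(
    input_ids=dummy_ids,
    attention_mask=dummy_mask,
    static_prefix=dummy_prefix,
    sensor_data=dummy_sensor,
    labels=dummy_ids.clone(),
)
print(outputs.loss)  # the loss already excludes the prefix positions (-100 labels)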
# class PrefixKGEmbedding(nn.Module):
#     def __init__(
#         self,
#         num_ent,
#         num_rel,
#         dim_llm,
#         num_prefix
#     ):
#         super(PrefixKGEmbedding, self).__init__()
#         self.emb_dim = num_prefix * dim_llm
#         self.ent_embeddings = nn.Embedding(num_ent, self.emb_dim)
#         self.rel_embeddings = nn.Embedding(num_rel, self.emb_dim)
#
#
#     def forward(self, triple_ids):
#         head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
#         h = self.ent_embeddings(head)
#         r = self.rel_embeddings(relation)
#         t = self.ent_embeddings(tail)
#         prefix = torch.stack((h, r, t), dim=1)
#         return prefix

class PretrainKGEmbedding(nn.Module):
    def __init__(
@@ -159,7 +214,7 @@ class PretrainKGEmbedding(nn.Module):
        self.ent_embeddings.requires_grad_(False)
        self.rel_embeddings.requires_grad_(False)
        self.adapter = nn.Linear(self.pretrain_dim, self.emb_dim)


    def forward(self, triple_ids):
        # main training stage
BIN models/Llama-3.2-3B-Instruct/tokenizer.model (Stored with Git LFS) (Normal file)
Binary file not shown.