from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding
from qwen_generation_utils import decode_tokens
import torch
import datasets

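# `qwen_generation_utils` provides Qwen's decode_tokens helper and `metric` (imported at the
# bottom) provides compute_metrics; both are assumed to be local modules next to this script.
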
model_dir = './model'
tokenizer = AutoTokenizer.from_pretrained(model_dir, device_map="auto", padding_side='left', trust_remote_code=True)
# Set `torch_dtype=torch.float16` to load the model in float16; otherwise it is loaded
# as float32, which may cause an OOM error.
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", pad_token_id=tokenizer.eos_token_id, trust_remote_code=True, torch_dtype=torch.float16)
# (Optional) On low-resource devices you can load the model in 4-bit or 8-bit via bitsandbytes
# to save further GPU memory. InternLM 7B in 4-bit needs nearly 8 GB of GPU memory.
# pip install -U bitsandbytes
# 8-bit: model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, load_in_8bit=True)
# 4-bit: model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, load_in_4bit=True)
model = model.eval()

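# The commented-out block below converts the raw multi-turn dialogue data in
# ./data_dir/data.json into ChatML-style prompt/response pairs and writes them to
# ./data_dir/converted.json (run it once before evaluating).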
# # convert data
# import ujson
# def transform_conversation_data(raw_data):
#     try:
#         instruction = '<|im_start|>system\n' + raw_data.get("conversation", "")[0]['system'] + "<|im_end|>\n"
#
#         conversation = raw_data.get("conversation", [])
#         for i, dialog in enumerate(conversation):
#             instruction += "<|im_start|>user\n来访者:" + dialog["input"] + "<|im_end|>\n"
#             if i < len(conversation) - 1:
#                 instruction += "<|im_start|>assistant\n医生:" + dialog["output"] + "<|im_end|>\n"
#
#         response = conversation[-1]["output"] if conversation else ""
#         instruction += "<|im_start|>assistant\n医生:"
#         return {"instruction": instruction, "output": response}
#
#     except Exception as e:
#         pass
#
# with open('./data_dir/data.json', 'r', encoding='utf-8') as f1:
#     data = ujson.load(f1)
# with open('./data_dir/converted.json', 'w', encoding='utf-8') as f:
#     for j, item in enumerate(data):
#         temp = transform_conversation_data(item)
#         if temp:
#             transformed_data = ujson.dumps(temp, ensure_ascii=False)
#             f.write(transformed_data + '\n')

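# Each line of converted.json is one JSON object:
#   "instruction": the ChatML prompt, ending with "<|im_start|>assistant\n医生:"
#   "output":      the reference reply of the final turn, used below for scoring
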
# set test params
test_num = 1596    # number of test samples
batch_size = 12

# prepare data and dataloader
dataset = datasets.load_dataset('json', data_files='./data_dir/converted.json', split=f"train[:{test_num}]")
references = dataset['output'][:test_num]

hypotheses = []

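# Tokenize the prompt ("instruction") and the reference ("output"); also keep the raw
# character length of each prompt so decode_tokens can later strip the prompt text from
# the decoded generation (see raw_text_len in the loop below).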
def preprocess(data):
    length = list(map(len, data['instruction']))
    model_inputs = tokenizer(data['instruction'], max_length=512, truncation=True)
    labels = tokenizer(data['output'], padding=True, max_length=128, truncation=True)
    model_inputs['labels'] = labels['input_ids']
    model_inputs['length'] = length
    return model_inputs

preprocessed_dataset = dataset.map(preprocess, batched=True, remove_columns=['instruction', 'output'])

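# DataCollatorWithPadding pads every batch to its longest sequence; because the tokenizer
# was created with padding_side='left', the padding ends up on the left, which is what
# batched generation expects.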
collator = DataCollatorWithPadding(tokenizer=tokenizer)

from torch.utils.data import DataLoader
dataloader = DataLoader(preprocessed_dataset, batch_size=batch_size, collate_fn=collator)

# generate responses
stop_word = "<|im_end|>"
for batch in dataloader:
    batch_input_ids = torch.LongTensor(batch['input_ids']).to(model.device)
    batch_labels = batch['labels']
    attention_mask = torch.LongTensor(batch['attention_mask']).to(model.device)
    length = batch['length']
    batch_out_ids = model.generate(
        batch_input_ids,
        attention_mask=attention_mask,  # pass the mask explicitly so left padding is ignored
        return_dict_in_generate=False,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.1,
        eos_token_id=92542  # id of the <|im_end|> stop token (model-specific; verify with tokenizer.convert_tokens_to_ids)
    )

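    # Prompts are left-padded, so count the pad tokens at the start of each sequence;
    # decode_tokens then skips them (and the prompt itself via raw_text_len) to recover
    # only the newly generated reply.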
    padding_lens = [batch_input_ids[i].eq(tokenizer.pad_token_id).sum().item() for i in range(batch_input_ids.size(0))]
    batch_response = [
        decode_tokens(
            batch_out_ids[i][padding_lens[i]:],
            tokenizer,
            context_length=0,
            raw_text_len=length[i],
            chat_format="raw",
            verbose=False,
            errors='replace'
        ).replace("医生:", "") for i in range(batch_input_ids.size(0))]  # size(0): the last batch may be smaller than batch_size
    # keep only responses that contain the stop word; the stop word is replaced by a space
    # and the leading chunk (the reply text before it) is kept
    hypotheses.extend([r.replace(stop_word, " ").split()[0] for r in batch_response if stop_word in r])

# Load metric
from metric import compute_metrics

print(compute_metrics((hypotheses, references)))