111 lines
4.2 KiB
Python
111 lines
4.2 KiB
Python
|
from transformers import AutoModelForCausalLM, AutoTokenizer,DataCollatorWithPadding
|
|||
|
from qwen_generation_utils import decode_tokens
|
|||
|
import torch
|
|||
|
import datasets
|
|||
|
|
|||
|
|
|||
|
# Directory the tokenizer and model weights are loaded from.
model_dir = './model'

# The tokenizer is configured to pad on the left side of each sequence.
tokenizer = AutoTokenizer.from_pretrained(
    model_dir,
    device_map="auto",
    padding_side='left',
    trust_remote_code=True,
)

# Set `torch_dtype=torch.float16` to load the model in float16; otherwise it
# is loaded as float32 and might cause an OOM error.
#
# (Optional) On low-resource devices the model can be loaded in 4-bit or
# 8-bit via bitsandbytes to further save GPU memory.
# InternLM 7B in 4-bit will cost nearly 8GB GPU memory.
#   pip install -U bitsandbytes
#   8-bit: model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, load_in_8bit=True)
#   4-bit: model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    device_map="auto",
    pad_token_id=tokenizer.eos_token_id,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).eval()  # switch to evaluation mode right after loading
|
|||
|
|
|||
|
# # convert data
|
|||
|
# import ujson
|
|||
|
# def transform_conversation_data(raw_data):
|
|||
|
# try:
|
|||
|
# instruction = '<|im_start|>system\n'+raw_data.get("conversation", "")[0]['system'] + "<|im_end|>\n"
|
|||
|
|
|||
|
# conversation = raw_data.get("conversation", [])
|
|||
|
# for i, dialog in enumerate(conversation):
|
|||
|
# instruction += "<|im_start|>user\n来访者:" + dialog["input"]+ "<|im_end|>\n"
|
|||
|
|
|||
|
# if i < len(conversation) - 1:
|
|||
|
# instruction += "<|im_start|>assistant\n医生:" + dialog["output"]+"<|im_end|>\n"
|
|||
|
|
|||
|
# response = conversation[-1]["output"] if conversation else ""
|
|||
|
|
|||
|
# instruction +="<|im_start|>assistant\n医生:"
|
|||
|
|
|||
|
# return {"instruction": instruction, "output": response}
|
|||
|
|
|||
|
# except Exception as e:
|
|||
|
# pass
|
|||
|
|
|||
|
|
|||
|
# with open(f'./data_dir/data.json', 'r', encoding='utf-8') as f1:
|
|||
|
# data = ujson.load(f1)
|
|||
|
# with open(f'./data_dir/converted.json', 'w', encoding='utf-8') as f:
|
|||
|
# for j, item in enumerate(data):
|
|||
|
# temp=transform_conversation_data(item)
|
|||
|
# if temp:
|
|||
|
# transformed_data =ujson.dumps(temp, ensure_ascii=False)
|
|||
|
# f.write(transformed_data+'\n')
|
|||
|
|
|||
|
# Evaluation settings.
test_num = 1596  # number of test samples
batch_size = 12  # batch size handed to the DataLoader

# Prepare the data: take the first `test_num` records of the converted JSON
# file; the 'output' column provides the reference answers.
eval_split = f"train[:{test_num}]"
dataset = datasets.load_dataset(
    'json',
    data_files='./data_dir/converted.json',
    split=eval_split,
)
references = dataset['output'][:test_num]

# Generated responses are accumulated here by the generation loop.
hypotheses = []
|
|||
|
def preprocess(data):
    """Tokenize one batch of records for generation.

    Intended for `dataset.map(..., batched=True)`: `data['instruction']` and
    `data['output']` are parallel lists of strings.

    Returns the tokenizer output for the instructions (truncated to 512
    tokens, left unpadded so the collator can pad per batch) with two extra
    columns:

    * ``labels`` -- token ids of the outputs, truncated and padded to exactly
      128 tokens.
    * ``length`` -- character length of each instruction string.
    """
    # Character (not token) length of every prompt.
    length = [len(text) for text in data['instruction']]

    model_inputs = tokenizer(data['instruction'], max_length=512, truncation=True)

    # Pad to a fixed width. With `padding=True` the width is the longest
    # label of the current map() chunk, so it can differ between chunks;
    # DataCollatorWithPadding does not pad `labels`, and a DataLoader batch
    # mixing rows from two chunks would then fail to stack into one tensor.
    labels = tokenizer(
        data['output'],
        padding='max_length',
        max_length=128,
        truncation=True,
    )

    model_inputs['labels'] = labels['input_ids']
    model_inputs['length'] = length
    return model_inputs
|
|||
|
from torch.utils.data import DataLoader

# Tokenize the whole dataset in batches and drop the raw text columns so that
# only the fields produced by `preprocess` remain.
preprocessed_dataset = dataset.map(
    preprocess,
    batched=True,
    remove_columns=['instruction', 'output'],
)

# The collator pads each DataLoader batch using the tokenizer's settings.
collator = DataCollatorWithPadding(tokenizer=tokenizer)

dataloader = DataLoader(
    preprocessed_dataset,
    batch_size=batch_size,
    collate_fn=collator,
)
|
|||
|
|
|||
|
# Generate responses batch by batch and collect one hypothesis per sample.
stop_word = "<|im_end|>"

for batch in dataloader:
    batch_input_ids = torch.as_tensor(batch['input_ids'], dtype=torch.long).to(model.device)
    # The prompts are padded, so the mask must be passed to generate() for the
    # padding positions to be ignored.
    attention_mask = torch.as_tensor(batch['attention_mask'], dtype=torch.long).to(model.device)
    length = batch['length']

    batch_out_ids = model.generate(
        batch_input_ids,
        attention_mask=attention_mask,
        return_dict_in_generate=False,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.1,
        # NOTE(review): presumably the id of "<|im_end|>" in this tokenizer --
        # confirm against the tokenizer vocabulary.
        eos_token_id=92542,
    )

    # Number of padding positions per row, taken from the attention mask so
    # that a prompt token sharing the pad id is not miscounted as padding.
    padding_lens = (attention_mask == 0).sum(dim=1).tolist()

    # Iterate over the rows actually present: the final batch can be smaller
    # than `batch_size`.
    # NOTE(review): the commented-out data conversion in this file writes the
    # role prefix with a full-width colon; the replace below targets the
    # literal as written here -- confirm which form the model emits.
    batch_response = [
        decode_tokens(
            batch_out_ids[i][padding_lens[i]:],
            tokenizer,
            context_length=0,
            raw_text_len=length[i],
            chat_format="raw",
            verbose=False,
            errors='replace',
        ).replace("医生:", "")
        for i in range(batch_out_ids.size(0))
    ]

    # Keep exactly one hypothesis per sample so `hypotheses` stays aligned
    # with `references`: take the first whitespace-delimited segment once the
    # stop word is blanked out, or an empty string when nothing remains.
    for response in batch_response:
        segments = response.replace(stop_word, " ").split()
        hypotheses.append(segments[0] if segments else "")
|
|||
|
|
|||
|
|
|||
|
# Score the generated hypotheses against the references and print the result.
from metric import compute_metrics

metric_result = compute_metrics((hypotheses, references))
print(metric_result)
|