OliveSensorAPI/evaluate/InternLM2_7B_chat_eval.py
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader
from qwen_generation_utils import decode_tokens
import torch
import datasets

model_dir = './model'
# Pad on the left so generated tokens follow directly after each prompt.
tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side='left', trust_remote_code=True)
# Set `torch_dtype=torch.float16` to load the model in float16; otherwise it is loaded
# in float32 and may cause an OOM error.
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", pad_token_id=tokenizer.eos_token_id, trust_remote_code=True, torch_dtype=torch.float16)
# (Optional) On low-resource devices, load the model in 4-bit or 8-bit via bitsandbytes
# to further save GPU memory. InternLM 7B in 4-bit needs roughly 8 GB of GPU memory.
# pip install -U bitsandbytes
# 8-bit: model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, load_in_8bit=True)
# 4-bit: model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, load_in_4bit=True)
model = model.eval()
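# (Optional) Quick single-prompt sanity check before running the batch evaluation.
# A minimal sketch assuming the InternLM2 remote code exposes a `chat()` helper;
# skip it if your checkpoint does not provide one.
# response, _ = model.chat(tokenizer, "Hello", history=[])
# print(response)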
# Convert raw multi-turn conversation data into ChatML-style instruction/output pairs.
# "来访者" = client, "医生" = doctor; both prefixes are part of the data format.
# import ujson
# def transform_conversation_data(raw_data):
#     try:
#         instruction = '<|im_start|>system\n' + raw_data.get("conversation", "")[0]['system'] + "<|im_end|>\n"
#         conversation = raw_data.get("conversation", [])
#         for i, dialog in enumerate(conversation):
#             instruction += "<|im_start|>user\n来访者" + dialog["input"] + "<|im_end|>\n"
#             if i < len(conversation) - 1:
#                 instruction += "<|im_start|>assistant\n医生" + dialog["output"] + "<|im_end|>\n"
#         response = conversation[-1]["output"] if conversation else ""
#         instruction += "<|im_start|>assistant\n医生"
#         return {"instruction": instruction, "output": response}
#     except Exception:
#         return None  # skip malformed records; the caller checks for a falsy result
# with open('./data_dir/data.json', 'r', encoding='utf-8') as f1:
#     data = ujson.load(f1)
# with open('./data_dir/converted.json', 'w', encoding='utf-8') as f:
#     for item in data:
#         temp = transform_conversation_data(item)
#         if temp:
#             f.write(ujson.dumps(temp, ensure_ascii=False) + '\n')
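# Each line of converted.json produced above is a JSON object of the form
# (abbreviated; the exact text depends on your data):
# {"instruction": "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n来访者...<|im_end|>\n<|im_start|>assistant\n医生", "output": "..."}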
# Set test parameters.
test_num = 1596   # number of test samples
batch_size = 12
# Prepare data and dataloader.
dataset = datasets.load_dataset('json', data_files='./data_dir/converted.json', split=f"train[:{test_num}]")
references = dataset['output'][:test_num]
hypotheses = []
def preprocess(data):
    # Record the character length of each prompt so the generated text can be
    # separated from the prompt after decoding.
    length = list(map(len, data['instruction']))
    model_inputs = tokenizer(data['instruction'], max_length=512, truncation=True)
    labels = tokenizer(data['output'], padding=True, max_length=128, truncation=True)
    model_inputs['labels'] = labels['input_ids']
    model_inputs['length'] = length
    return model_inputs

preprocessed_dataset = dataset.map(preprocess, batched=True, remove_columns=['instruction', 'output'])
# DataCollatorWithPadding pads each batch to its longest sequence (on the left,
# following the tokenizer's padding_side).
collator = DataCollatorWithPadding(tokenizer=tokenizer)
dataloader = DataLoader(preprocessed_dataset, batch_size=batch_size, collate_fn=collator)
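# (Optional) Peek at one batch to confirm shapes and left padding -- a throwaway
# check, not part of the evaluation itself.
# probe = next(iter(dataloader))
# print(probe['input_ids'].shape, probe['attention_mask'].shape)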
# Generate responses batch by batch.
stop_word = "<|im_end|>"
for batch in dataloader:
    batch_input_ids = batch['input_ids'].to(model.device)
    attention_mask = batch['attention_mask'].to(model.device)
    length = batch['length']
    batch_out_ids = model.generate(
        batch_input_ids,
        attention_mask=attention_mask,
        return_dict_in_generate=False,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.1,
        eos_token_id=92542  # token id of <|im_end|> in the InternLM2 tokenizer
    )
    # With left padding, each row starts with pad tokens; skip them before decoding.
    padding_lens = [batch_input_ids[i].eq(tokenizer.pad_token_id).sum().item() for i in range(batch_input_ids.size(0))]
    batch_response = [
        decode_tokens(
            batch_out_ids[i][padding_lens[i]:],
            tokenizer,
            context_length=0,
            raw_text_len=length[i],
            chat_format="raw",
            verbose=False,
            errors='replace'
        ).replace("医生:", "")  # strip the "doctor" speaker prefix
        for i in range(batch_input_ids.size(0))  # use the actual batch size: the last batch may be smaller than batch_size
    ]
    # Truncate each response at the stop token.
    hypotheses.extend([r.split(stop_word)[0] for r in batch_response])
# Load metric and report scores.
from metric import compute_metrics
print(compute_metrics((hypotheses, references)))