add InternLM2_7B eval

ZeyuBa 2024-03-03 17:20:16 +08:00
parent 84581eb9e0
commit d825895a79
8 changed files with 124 additions and 22 deletions


@@ -43,13 +43,9 @@ pip install torch transformers datasets nltk rouge jieba
 ## Test Results
-Results of testing the fully fine-tuned Qwen1_5-0_5B-Chat model on the data in data.json are as follows:
-| Metric  | Value  |
-|---------|--------|
-| ROUGE-1 | 27.23% |
-| ROUGE-2 | 8.55%  |
-| ROUGE-L | 17.05% |
-| BLEU-1  | 26.65% |
-| BLEU-2  | 13.11% |
-| BLEU-3  | 7.19%  |
-| BLEU-4  | 4.05%  |
+Test results on the data in data.json are as follows:
+| Model | ROUGE-1 | ROUGE-2 | ROUGE-L | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 |
+|-------|---------|---------|---------|--------|--------|--------|--------|
+| Qwen1_5-0_5B-Chat | 27.23% | 8.55% | 17.05% | 26.65% | 13.11% | 7.19% | 4.05% |
+| InternLM2_7B_chat | 37.86% | 15.23% | 24.34% | 39.71% | 22.66% | 14.26% | 9.21% |
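The ROUGE/BLEU figures above are word-overlap scores over segmented Chinese text. Below is a minimal sketch of that style of computation, assuming jieba segmentation plus the `rouge` and `nltk` packages from the install line; the sentences are made up for illustration, and the repository's actual pipeline lives in `metric.py` and the eval scripts.

```python
import jieba
from rouge import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# made-up reference/hypothesis pair, only for illustrating the metric calls
reference = "医生建议保持规律作息，适当运动。"
hypothesis = "医生建议规律作息和适当运动。"

ref_tokens = list(jieba.cut(reference))
hyp_tokens = list(jieba.cut(hypothesis))

# the rouge package expects whitespace-separated token strings (hypothesis first)
scores = Rouge().get_scores(" ".join(hyp_tokens), " ".join(ref_tokens))[0]
print({k: round(v["f"], 4) for k, v in scores.items()})  # rouge-1 / rouge-2 / rouge-l F1

# BLEU-1..4 with uniform n-gram weights; smoothing avoids zero scores on short sentences
smooth = SmoothingFunction().method3
weights = [(1., 0, 0, 0), (.5, .5), (1/3, 1/3, 1/3), (.25, .25, .25, .25)]
for n, w in enumerate(weights, start=1):
    bleu_n = sentence_bleu([ref_tokens], hyp_tokens, weights=w, smoothing_function=smooth)
    print(f"BLEU-{n}:", round(bleu_n, 4))
```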


@@ -0,0 +1,111 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding
from qwen_generation_utils import decode_tokens
import torch
import datasets

model_dir = './model'
tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side='left', trust_remote_code=True)
if tokenizer.pad_token is None:
    # DataCollatorWithPadding and the pad-length bookkeeping below need a pad token
    tokenizer.pad_token = tokenizer.eos_token
# Set `torch_dtype=torch.float16` to load the model in float16; otherwise it is loaded in float32 and may cause an OOM error.
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", pad_token_id=tokenizer.eos_token_id, trust_remote_code=True, torch_dtype=torch.float16)
# (Optional) On low-resource devices you can load the model in 4-bit or 8-bit via bitsandbytes to further save GPU memory.
# InternLM 7B in 4-bit costs nearly 8 GB of GPU memory.
# pip install -U bitsandbytes
# 8-bit: model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, load_in_8bit=True)
# 4-bit: model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, load_in_4bit=True)
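# Equivalently, recent transformers versions accept a BitsAndBytesConfig (illustrative sketch, not required):
#   from transformers import BitsAndBytesConfig
#   quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
#   model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, quantization_config=quant_config)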
model = model.eval()
# # convert data
# import ujson
# def transform_conversation_data(raw_data):
# try:
# instruction = '<|im_start|>system\n'+raw_data.get("conversation", "")[0]['system'] + "<|im_end|>\n"
# conversation = raw_data.get("conversation", [])
# for i, dialog in enumerate(conversation):
# instruction += "<|im_start|>user\n来访者" + dialog["input"]+ "<|im_end|>\n"
# if i < len(conversation) - 1:
# instruction += "<|im_start|>assistant\n医生" + dialog["output"]+"<|im_end|>\n"
# response = conversation[-1]["output"] if conversation else ""
# instruction +="<|im_start|>assistant\n医生"
# return {"instruction": instruction, "output": response}
# except Exception as e:
# pass
# with open(f'./data_dir/data.json', 'r', encoding='utf-8') as f1:
# data = ujson.load(f1)
# with open(f'./data_dir/converted.json', 'w', encoding='utf-8') as f:
# for j, item in enumerate(data):
# temp=transform_conversation_data(item)
# if temp:
# transformed_data =ujson.dumps(temp, ensure_ascii=False)
# f.write(transformed_data+'\n')
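# For reference, each line the commented-out converter above writes to ./data_dir/converted.json
# (and that the loader below expects) is one JSON object of this rough shape (illustrative, made-up content):
#   {"instruction": "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n来访者...<|im_end|>\n<|im_start|>assistant\n医生",
#    "output": "..."}
# i.e. `instruction` holds the ChatML-formatted dialogue history ending with the assistant prefix,
# and `output` holds the reference reply that ROUGE/BLEU are computed against.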
# set test params
test_num = 1596   # number of test samples
batch_size = 12
# prepare data and dataloader
dataset = datasets.load_dataset('json', data_files='./data_dir/converted.json', split=f"train[:{test_num}]")
references = dataset['output'][:test_num]
hypotheses = []

def preprocess(data):
    # keep the character length of each prompt so the decoded prompt can be stripped from the output later
    length = list(map(len, data['instruction']))
    model_inputs = tokenizer(data['instruction'], max_length=512, truncation=True)
    # pad labels to a fixed length so the collator can batch them even across map-batch boundaries
    labels = tokenizer(data['output'], padding='max_length', max_length=128, truncation=True)
    model_inputs['labels'] = labels['input_ids']
    model_inputs['length'] = length
    return model_inputs

preprocessed_dataset = dataset.map(preprocess, batched=True, remove_columns=['instruction', 'output'])
collator = DataCollatorWithPadding(tokenizer=tokenizer)

from torch.utils.data import DataLoader
dataloader = DataLoader(preprocessed_dataset, batch_size=batch_size, collate_fn=collator)
# generate responses
stop_word = "<|im_end|>"
for batch in dataloader:
    batch_input_ids = torch.LongTensor(batch['input_ids']).to(model.device)
    batch_labels = batch['labels']
    attention_mask = batch['attention_mask']
    length = batch['length']
    batch_out_ids = model.generate(
        batch_input_ids,
        return_dict_in_generate=False,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.1,
        eos_token_id=92542  # id of <|im_end|> in the InternLM2 tokenizer, matching stop_word above
    )
    # with left padding, count each prompt's pad tokens so they can be sliced off the front of the output
    padding_lens = [batch_input_ids[i].eq(tokenizer.pad_token_id).sum().item() for i in range(batch_input_ids.size(0))]
    batch_response = [
        decode_tokens(
            batch_out_ids[i][padding_lens[i]:],
            tokenizer,
            context_length=0,
            raw_text_len=length[i],
            chat_format="raw",
            verbose=False,
            errors='replace'
        ).replace("医生:", "")
        for i in range(batch_input_ids.size(0))  # use the actual batch size: the last batch may be smaller than batch_size
    ]
    # truncate each response at the stop word but keep every response, so hypotheses stay aligned with references
    hypotheses.extend([r.split(stop_word)[0].strip() for r in batch_response])
# Load metric
from metric import compute_metrics
print(compute_metrics((hypotheses,references)))


@@ -25,7 +25,7 @@ batch_size=12
 #prepare data and dataloader
-dataset = datasets.load_dataset('json', data_files='./train_dir/converted.json',split=f"train[:{test_num}]")
+dataset = datasets.load_dataset('json', data_files='./data_dir/converted.json',split=f"train[:{test_num}]")
 references =dataset['output'][:test_num]
 hypotheses = []


@@ -4,15 +4,10 @@
 * For the specific metrics and methodology, see General evaluation.md
-| Metric  | Value  |
-|---------|--------|
-| ROUGE-1 | 27.23% |
-| ROUGE-2 | 8.55%  |
-| ROUGE-L | 17.05% |
-| BLEU-1  | 26.65% |
-| BLEU-2  | 13.11% |
-| BLEU-3  | 7.19%  |
-| BLEU-4  | 4.05%  |
+| Model | ROUGE-1 | ROUGE-2 | ROUGE-L | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 |
+|-------|---------|---------|---------|--------|--------|--------|--------|
+| Qwen1_5-0_5B-Chat | 27.23% | 8.55% | 17.05% | 26.65% | 13.11% | 7.19% | 4.05% |
+| InternLM2_7B_chat | 37.86% | 15.23% | 24.34% | 39.71% | 22.66% | 14.26% | 9.21% |
 ## Professional Metric Evaluation


@@ -18,8 +18,8 @@ def compute_metrics(eval_pred):
     rouge = Rouge()
-    bleu =np.array([0,0,0,0])
-    weights = [(1.,0,0,0),(1./2., 1./2.),(1./3., 1./3., 1./3.),(1./4., 1./4., 1./4., 1./4.)]
+    bleu =np.array([0.,0.,0.,0.])
+    weights = [(1.,0.,0.,0.),(1./2., 1./2.),(1./3., 1./3., 1./3.),(1./4., 1./4., 1./4., 1./4.)]
     for decoded_label, decoded_pred in zip(decoded_labels, decoded_preds):
         bleu +=np.array( sentence_bleu(
             references=[decoded_label.split(' ')],
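The switch from an integer to a float accumulator above is what makes the per-sentence BLEU sums work: NumPy refuses to cast a float result back into an int array during an in-place add. A small standalone sketch with made-up score values:

```python
import numpy as np

# hypothetical per-sentence BLEU-1..4 scores for two sentence pairs (illustration only)
per_sentence = [np.array([0.40, 0.25, 0.10, 0.05]),
                np.array([0.60, 0.35, 0.20, 0.10])]

bleu_int = np.array([0, 0, 0, 0])  # int64 accumulator, as before the fix
try:
    for s in per_sentence:
        bleu_int += s  # in-place add cannot cast the float64 result back to int64
except TypeError as err:
    print("integer accumulator fails:", err)

bleu = np.array([0., 0., 0., 0.])  # float64 accumulator, as after the fix
for s in per_sentence:
    bleu += s
print("averaged BLEU-1..4:", bleu / len(per_sentence))
```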