# OliveSensorAPI/evaluate/metric.py
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge
import numpy as np
import jieba
# Cumulative n-gram weights for BLEU-1 .. BLEU-4; passing the whole list to
# sentence_bleu scores all four variants in a single call.
_BLEU_WEIGHTS = [
    (1.0, 0.0, 0.0, 0.0),
    (1.0 / 2.0, 1.0 / 2.0),
    (1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0),
    (1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0),
]


def _segment(text, char_level=False):
    """Strip ASCII spaces from *text* and re-join its tokens with single spaces.

    Word level (the default) segments with ``jieba.cut``; character level makes
    every remaining character its own token.
    """
    compact = text.replace(" ", "")
    if char_level:
        return " ".join(compact)
    return " ".join(jieba.cut(compact))


def _mean_rouge_f(rouge, hyps, refs):
    """Return ``{metric: mean F1}`` over all pairs, counting unscorable pairs as 0.

    ``Rouge.get_scores`` raises ``ValueError`` for an empty hypothesis or
    reference; with ``avg=True`` a single such pair aborts the whole batch.
    Scoring pair by pair lets those pairs contribute 0 instead, while the
    average is still taken over *all* pairs. For batches without empty pairs
    this should equal ``get_scores(hyps, refs, avg=True)`` F1 values.
    If no pair can be scored, the returned dict is empty.
    """
    totals = {}
    for hyp, ref in zip(hyps, refs):
        try:
            sample = rouge.get_scores(hyp, ref)[0]
        except ValueError:
            continue  # empty hypothesis/reference: contributes 0 to every metric
        for metric, stats in sample.items():
            totals[metric] = totals.get(metric, 0.0) + stats["f"]
    return {metric: total / len(hyps) for metric, total in totals.items()}


def compute_metrics(eval_pred, *, char_level=False):
    """Compute ROUGE F1 and BLEU-1..4 for text predictions, scaled to 0-100.

    Every prediction and label has its ASCII spaces removed and is then
    re-tokenized (jieba words by default, or single characters) before scoring.

    Args:
        eval_pred: A ``(predictions, labels)`` pair of equal-length sequences
            of strings.
        char_level: If true, tokenize at character level instead of word level.
            Defaults to word level, the original behaviour.

    Returns:
        A dict mapping each ROUGE metric name reported by ``rouge`` to its mean
        F1 * 100, plus ``"bleu"`` -> ``{"bleu_1": ..., "bleu_2": ...,
        "bleu_3": ..., "bleu_4": ...}`` holding the mean sentence BLEU * 100.

    Raises:
        ValueError: If predictions and labels differ in length, or are empty.
    """
    predictions, labels = eval_pred
    decoded_preds = [_segment(pred, char_level) for pred in predictions]
    decoded_labels = [_segment(label, char_level) for label in labels]

    # Fail loudly instead of letting zip() silently truncate the BLEU loop
    # or dividing by zero below.
    if len(decoded_preds) != len(decoded_labels):
        raise ValueError(
            f"got {len(decoded_preds)} predictions but {len(decoded_labels)} labels"
        )
    if not decoded_labels:
        raise ValueError("compute_metrics needs at least one example")

    # One smoother instance serves every pair; no need to rebuild it per sample.
    smoothing = SmoothingFunction().method1
    bleu = np.zeros(len(_BLEU_WEIGHTS))
    for decoded_label, decoded_pred in zip(decoded_labels, decoded_preds):
        bleu += np.array(
            sentence_bleu(
                references=[decoded_label.split(" ")],
                hypothesis=decoded_pred.split(" "),
                smoothing_function=smoothing,
                weights=_BLEU_WEIGHTS,
            )
        )
    bleu /= len(decoded_labels)

    rouge_f = _mean_rouge_f(Rouge(), decoded_preds, decoded_labels)
    result = {metric: f * 100 for metric, f in rouge_f.items()}
    result["bleu"] = {
        f"bleu_{order}": score * 100 for order, score in enumerate(bleu, start=1)
    }
    return result