add evaluation part

commit 26768a2c75 (parent 37161c84d1)

evaluate/README.md (new file, 56 lines)

# EmoLLM General-Metric Evaluation

## Introduction

This README explains how to use the `eval.py` and `metric.py` scripts, which evaluate the responses generated by EmoLLM, a mental-health large language model.

## Installation

- Python 3.x
- PyTorch
- Transformers
- Datasets
- NLTK
- Rouge
- Jieba

These can be installed with:

```bash
pip install torch transformers datasets nltk rouge jieba
```

## Usage

### convert.py

Converts the raw multi-turn dialogue data into single-turn samples for evaluation; the transformation is sketched below.

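For illustration, here is a minimal sketch of the flattening logic, assuming each raw record holds a `conversation` list with `system`, `input`, and `output` fields as `convert.py` expects; the sample record is a hypothetical placeholder:

```python
# Hypothetical record; the field names follow convert.py, the text is placeholder only.
record = {
    "conversation": [
        {"system": "<system prompt>", "input": "<visitor turn 1>", "output": "<doctor turn 1>"},
        {"input": "<visitor turn 2>", "output": "<final doctor turn>"},
    ]
}

# Every visitor turn, plus all doctor turns except the last, become the prompt;
# the final doctor turn becomes the reference output.
instruction = record["conversation"][0]["system"] + "\n\n对话:"
for i, turn in enumerate(record["conversation"]):
    instruction += "\n来访者:" + turn["input"]
    if i < len(record["conversation"]) - 1:
        instruction += "\n医生:" + turn["output"]
instruction += "\n医生:"

sample = {"instruction": instruction, "output": record["conversation"][-1]["output"]}
print(sample)
```
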
### eval.py

`eval.py` generates the doctor replies and evaluates them. It has four main parts, condensed into the single-prompt sketch after this list:

1. Load the model and tokenizer.
2. Set the test parameters, such as the number of test samples and the batch size.
3. Prepare the data.
4. Generate responses and evaluate them.

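The sketch below runs one prompt end to end. It is illustrative only: it assumes the fine-tuned checkpoint directory `./EmoLLM_Qwen1_5-0_5B-Chat_full_sft` and `./train_dir/converted.json` from this commit are available locally, and it skips the batching and left-padding handling that `eval.py` performs:

```python
import datasets
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from metric import compute_metrics

# Assumed local paths, matching the ones used in eval.py.
ckpt = './EmoLLM_Qwen1_5-0_5B-Chat_full_sft'
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="cuda:0", trust_remote_code=True).eval()

# One converted sample: 'instruction' is the prompt, 'output' is the reference reply.
sample = datasets.load_dataset('json', data_files='./train_dir/converted.json', split="train[:1]")[0]

inputs = tokenizer(sample['instruction'], return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=256, temperature=0.1,
                         pad_token_id=tokenizer.eos_token_id)

# Keep only the newly generated tokens, i.e. the doctor's reply.
reply = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

print(compute_metrics(([reply], [sample['output']])))
```
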
### metric.py

`metric.py` contains the function that computes the evaluation metrics. Evaluation can be switched between character level and word level, and currently covers BLEU and ROUGE scores; a usage sketch follows.

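A minimal usage sketch (the prediction/reference pair is a made-up placeholder; with the word-level branch enabled, as in the current `metric.py`, the result is a dict of ROUGE F-scores plus a nested `bleu` dict):

```python
from metric import compute_metrics

# Hypothetical prediction/reference pair, for illustration only.
hypotheses = ["你可以先从规律作息开始调整。"]
references = ["建议你先从规律的作息开始调整。"]

result = compute_metrics((hypotheses, references))
# result has the shape:
# {'rouge-1': ..., 'rouge-2': ..., 'rouge-l': ...,
#  'bleu': {'bleu_1': ..., 'bleu_2': ..., 'bleu_3': ..., 'bleu_4': ...}}
print(result)
```
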
## Test Results

Results on the data in data.json, evaluated with the fully fine-tuned Qwen1_5-0_5B-Chat model:

| Metric  | Value  |
|---------|--------|
| ROUGE-1 | 27.23% |
| ROUGE-2 | 8.55%  |
| ROUGE-L | 17.05% |
| BLEU-1  | 26.65% |
| BLEU-2  | 13.11% |
| BLEU-3  | 7.19%  |
| BLEU-4  | 4.05%  |

evaluate/eval.py (new file, 80 lines)

from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader
from qwen_generation_utils import decode_tokens
import torch
import datasets

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    './EmoLLM_Qwen1_5-0_5B-Chat_full_sft',
    pad_token='<|extra_0|>',
    eos_token='<|endoftext|>',
    padding_side='left',
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    './EmoLLM_Qwen1_5-0_5B-Chat_full_sft',
    pad_token_id=tokenizer.pad_token_id,
    device_map="cuda:0",
    trust_remote_code=True
).eval()


# Set test parameters
test_num = 1596   # number of test samples
batch_size = 12


# Prepare data and dataloader
dataset = datasets.load_dataset('json', data_files='./train_dir/converted.json', split=f"train[:{test_num}]")
references = dataset['output'][:test_num]

hypotheses = []

def preprocess(data):
    length = list(map(len, data['instruction']))
    model_inputs = tokenizer(data['instruction'], max_length=512, truncation=True)
    labels = tokenizer(data['output'], padding=True, max_length=128, truncation=True)
    model_inputs['labels'] = labels['input_ids']
    model_inputs['length'] = length
    return model_inputs

preprocessed_dataset = dataset.map(preprocess, batched=True, remove_columns=['instruction', 'output'])

collator = DataCollatorWithPadding(tokenizer=tokenizer)
dataloader = DataLoader(preprocessed_dataset, batch_size=batch_size, collate_fn=collator)


# Generate responses
for batch in dataloader:
    batch_input_ids = torch.LongTensor(batch['input_ids']).to(model.device)
    batch_labels = batch['labels']
    attention_mask = batch['attention_mask']
    length = batch['length']
    batch_out_ids = model.generate(
        batch_input_ids.to(model.device),
        return_dict_in_generate=False,
        max_new_tokens=256,
        temperature=0.1,
        pad_token_id=tokenizer.eos_token_id
    )
    # Left-padding length of each prompt, used to strip the prompt (and its padding) from the output.
    padding_lens = [batch_input_ids[i].eq(tokenizer.pad_token_id).sum().item() for i in range(batch_input_ids.size(0))]
    batch_response = [
        decode_tokens(
            batch_out_ids[i][padding_lens[i]:],
            tokenizer,
            context_length=0,
            raw_text_len=length[i],
            chat_format="raw",
            verbose=False,
            errors='replace'
        ) for i in range(batch_input_ids.size(0))  # iterate over the actual batch size (the last batch may be smaller)
    ]
    hypotheses.extend(batch_response)


# Compute metrics
from metric import compute_metrics

print(compute_metrics((hypotheses, references)))

evaluate/metric.py (new file, 33 lines)

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge
import numpy as np
import jieba

def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Character level
    # decoded_preds = [" ".join(pred.replace(" ", "")) for pred in predictions]
    # decoded_labels = [" ".join(label.replace(" ", "")) for label in labels]

    # Word level (jieba segmentation)
    decoded_preds = [" ".join(jieba.cut(pred.replace(" ", ""))) for pred in predictions]
    decoded_labels = [" ".join(jieba.cut(label.replace(" ", ""))) for label in labels]

    rouge = Rouge()

    # Accumulate BLEU-1..4 over all samples, then average.
    bleu = np.array([0., 0., 0., 0.])
    weights = [(1., 0, 0, 0), (1. / 2., 1. / 2.), (1. / 3., 1. / 3., 1. / 3.), (1. / 4., 1. / 4., 1. / 4., 1. / 4.)]
    for decoded_label, decoded_pred in zip(decoded_labels, decoded_preds):
        bleu += np.array(sentence_bleu(
            references=[decoded_label.split(' ')],
            hypothesis=decoded_pred.split(' '),
            smoothing_function=SmoothingFunction().method1, weights=weights
        ))
    bleu /= len(decoded_labels)
    result = rouge.get_scores(decoded_preds, decoded_labels, avg=True)
    result = {key: value['f'] * 100 for key, value in result.items()}
    result["bleu"] = {'bleu_1': bleu[0] * 100, 'bleu_2': bleu[1] * 100, 'bleu_3': bleu[2] * 100, 'bleu_4': bleu[3] * 100}
    return result

evaluate/qwen_generation_utils.py (new file, 416 lines)

# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Generation support."""

from typing import Tuple, List, Union, Iterable

import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
from transformers import logging
from transformers.generation import LogitsProcessor

logger = logging.get_logger(__name__)

# Types.
HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]


def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
    for tokens in batch:
        context_length = len(tokens)
        if context_length < seq_length:
            tokens.extend([pad_id] * (seq_length - context_length))
    return batch


def get_ltor_masks_and_position_ids(
    data,
    eod_token,
    reset_position_ids,
    reset_attention_mask,
    eod_mask_loss,
):
    """Build masks and position id for left to right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(
        torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
    ).view(att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1) :] -= i + 1 - prev_index
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = attention_mask < 0.5

    return attention_mask, loss_mask, position_ids


def get_batch(context_tokens: torch.LongTensor, eod_id: int):
    """Generate batch from context tokens."""
    # Move to GPU.
    tokens = context_tokens.contiguous().to(context_tokens.device)
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        eod_id,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False,
    )
    return tokens, attention_mask, position_ids


def get_stop_words_ids(chat_format, tokenizer):
    if chat_format == "raw":
        stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
    elif chat_format == "chatml":
        stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    else:
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")
    return stop_words_ids


def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: List[Tuple[str, str]] = None,
    system: str = "",
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    if history is None:
        history = []

    if chat_format == "chatml":
        im_start, im_end = "<|im_start|>", "<|im_end|>"
        im_start_tokens = [tokenizer.im_start_id]
        im_end_tokens = [tokenizer.im_end_id]
        nl_tokens = tokenizer.encode("\n")

        def _tokenize_str(role, content):
            return f"{role}\n{content}", tokenizer.encode(
                role, allowed_special=set()
            ) + nl_tokens + tokenizer.encode(content, allowed_special=set())

        system_text, system_tokens_part = _tokenize_str("system", system)
        system_tokens = im_start_tokens + system_tokens_part + im_end_tokens

        raw_text = ""
        context_tokens = []

        for turn_query, turn_response in reversed(history):
            query_text, query_tokens_part = _tokenize_str("user", turn_query)
            query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
            response_text, response_tokens_part = _tokenize_str(
                "assistant", turn_response
            )
            response_tokens = im_start_tokens + response_tokens_part + im_end_tokens

            next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
            prev_chat = (
                f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
            )

            current_context_size = (
                len(system_tokens) + len(next_context_tokens) + len(context_tokens)
            )
            if current_context_size < max_window_size:
                context_tokens = next_context_tokens + context_tokens
                raw_text = prev_chat + raw_text
            else:
                break

        context_tokens = system_tokens + context_tokens
        raw_text = f"{im_start}{system_text}{im_end}" + raw_text
        context_tokens += (
            nl_tokens
            + im_start_tokens
            + _tokenize_str("user", query)[1]
            + im_end_tokens
            + nl_tokens
            + im_start_tokens
            + tokenizer.encode("assistant")
            + nl_tokens
        )
        raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"

    elif chat_format == "raw":
        raw_text = query
        context_tokens = tokenizer.encode(raw_text)
    else:
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")

    return raw_text, context_tokens


def _decode_default(
    tokens: List[int],
    *,
    stop_words: List[str],
    eod_words: List[str],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    verbose: bool = False,
    return_end_reason: bool = False,
    errors: str = 'replace',
):
    trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
    if verbose:
        print("\nRaw Generate: ", trim_decode_tokens)

    end_reason = f"Gen length {len(tokens)}"
    for stop_word in stop_words:
        trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
    for eod_word in eod_words:
        if eod_word in trim_decode_tokens:
            end_reason = f"Gen {eod_word!r}"
        trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
    trim_decode_tokens = trim_decode_tokens.strip()
    if verbose:
        print("\nEnd Reason:", end_reason)
        print("\nGenerate: ", trim_decode_tokens)

    if return_end_reason:
        return trim_decode_tokens, end_reason
    else:
        return trim_decode_tokens


def _decode_chatml(
    tokens: List[int],
    *,
    stop_words: List[str],
    eod_token_ids: List[int],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    verbose: bool = False,
    return_end_reason: bool = False,
    errors: str = 'replace'
):
    end_reason = f"Gen length {len(tokens)}"
    eod_token_idx = context_length
    for eod_token_idx in range(context_length, len(tokens)):
        if tokens[eod_token_idx] in eod_token_ids:
            end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
            break

    trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
    if verbose:
        print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
        print("\nRaw Generate:", trim_decode_tokens)
        print("\nEnd Reason:", end_reason)
    for stop_word in stop_words:
        trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
    trim_decode_tokens = trim_decode_tokens.strip()
    if verbose:
        print("\nGenerate:", trim_decode_tokens)

    if return_end_reason:
        return trim_decode_tokens, end_reason
    else:
        return trim_decode_tokens


def decode_tokens(
    tokens: Union[torch.LongTensor, TokensType],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    chat_format: str,
    verbose: bool = False,
    return_end_reason: bool = False,
    errors: str = "replace",
) -> str:
    if torch.is_tensor(tokens):
        tokens = tokens.cpu().numpy().tolist()

    if chat_format == "chatml":
        return _decode_chatml(
            tokens,
            stop_words=[],
            eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
            tokenizer=tokenizer,
            raw_text_len=raw_text_len,
            context_length=context_length,
            verbose=verbose,
            return_end_reason=return_end_reason,
            errors=errors,
        )
    elif chat_format == "raw":
        return _decode_default(
            tokens,
            stop_words=["<|endoftext|>"],
            eod_words=["<|endoftext|>"],
            tokenizer=tokenizer,
            raw_text_len=raw_text_len,
            verbose=verbose,
            return_end_reason=return_end_reason,
            errors=errors,
        )
    else:
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")


class StopWordsLogitsProcessor(LogitsProcessor):
    """
    :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop generation.

    Args:
        stop_words_ids (:obj:`List[List[int]]`):
            List of list of token ids of stop ids. In order to get the tokens of the words
            that should not appear in the generated text, use :obj:`tokenizer(bad_word,
            add_prefix_space=True).input_ids`.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.
    """

    def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):

        if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
            raise ValueError(
                f"`stop_words_ids` has to be a non-empty list, but is {stop_words_ids}."
            )
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
            raise ValueError(
                f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
            )
        if any(
            any(
                (not isinstance(token_id, (int, np.integer)) or token_id < 0)
                for token_id in stop_word_ids
            )
            for stop_word_ids in stop_words_ids
        ):
            raise ValueError(
                f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
            )

        self.stop_words_ids = list(
            filter(
                lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
            )
        )
        self.eos_token_id = eos_token_id
        for stop_token_seq in self.stop_words_ids:
            assert (
                len(stop_token_seq) > 0
            ), "Stop words token sequences {} cannot have an empty list".format(
                stop_words_ids
            )

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        stopped_samples = self._calc_stopped_samples(input_ids)
        for i, should_stop in enumerate(stopped_samples):
            if should_stop:
                scores[i, self.eos_token_id] = float(2**15)
        return scores

    def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
        if len(tokens) == 0:
            # if bad word tokens is just one token always ban it
            return True
        elif len(tokens) > len(prev_tokens):
            # if bad word tokens are longer than prev input_ids they can't be equal
            return False
        elif prev_tokens[-len(tokens) :].tolist() == tokens:
            # if tokens match
            return True
        else:
            return False

    def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
        stopped_samples = []
        for prev_input_ids_slice in prev_input_ids:
            match = False
            for stop_token_seq in self.stop_words_ids:
                if self._tokens_match(prev_input_ids_slice, stop_token_seq):
                    # a stop word sequence matched; mark this sample as stopped
                    match = True
                    break
            stopped_samples.append(match)

        return stopped_samples


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """This function has been mostly taken from huggingface conversational
    ai code at
        https://medium.com/huggingface/how-to-build-a-state-of-the-art-
             conversational-ai-with-transfer-learning-2d818ac26313"""

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Convert to 1D
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2

evaluate/train_dir/convert.py (new file, 31 lines)

import ujson

def transform_conversation_data(raw_data):
    """Flatten one multi-turn conversation into a single instruction/output pair."""
    try:
        instruction = raw_data.get("conversation", "")[0]['system'] + "\n\n对话:"

        conversation = raw_data.get("conversation", [])
        for i, dialog in enumerate(conversation):
            instruction += "\n来访者:" + dialog["input"]

            # All doctor turns except the last one stay in the prompt.
            if i < len(conversation) - 1:
                instruction += "\n医生:" + dialog["output"]

        # The final doctor turn is the reference output.
        response = conversation[-1]["output"] if conversation else ""

        instruction += "\n医生:"

        return {"instruction": instruction, "output": response}

    except Exception:
        # Malformed records are skipped; the caller ignores a None return.
        return None


with open('./train_dir/data.json', 'r', encoding='utf-8') as f1:
    data = ujson.load(f1)
with open('./train_dir/converted.json', 'w', encoding='utf-8') as f:
    for j, item in enumerate(data):
        temp = transform_conversation_data(item)
        if temp:
            transformed_data = ujson.dumps(temp, ensure_ascii=False)
            f.write(transformed_data + '\n')
    print('********')

evaluate/train_dir/converted.json (new file, 1596 lines): diff not shown because the file is too large.

evaluate/train_dir/data.json (new file, 28282 lines): diff not shown because the file is too large.