85 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			85 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								# -*- coding: utf-8 -*-
							 | 
						|||
| 
								 | 
							
								import json
							 | 
						|||
| 
								 | 
							
								import os
							 | 
						|||
| 
								 | 
							
								from tqdm import tqdm
							 | 
						|||
| 
								 | 
							
								import SparkApi
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
# File locations used by this script:
#   input_file      - JSONL dataset read line by line; replaced by temp_file at the end
#   checkpoint_file - holds the progress value written by save_checkpoint()
#   temp_file       - scratch output built during the run, then moved over input_file
input_file, checkpoint_file, temp_file = (
    'output/train_expanded.jsonl',
    'output/expand_checkpoint.txt',
    'output/tmp_train_expanded.jsonl',
)
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
# Call the Spark chat API to generate an answer for one question.
def generate_answer_via_api(question):
    """Ask the Spark chat service to answer *question* and return the reply.

    Credentials are read from the ``SPARK_APPID``, ``SPARK_API_SECRET`` and
    ``SPARK_API_KEY`` environment variables when set. Otherwise the values
    that used to be hard-coded here are used, so existing runs behave exactly
    as before.

    NOTE(review): the literal fallbacks below appear to be real credentials
    committed to source. Rotate them and delete the fallbacks once the
    environment variables are configured wherever this script runs.

    Args:
        question: question text embedded into the prompt.

    Returns:
        ``SparkApi.answer`` after ``SparkApi.main`` returns, with surrounding
        whitespace stripped. It is reset to "" before the call. Presumably
        ``SparkApi.main`` accumulates the model's reply into it (TODO
        confirm), so an empty string likely means no reply was received.
    """
    appid = os.environ.get("SPARK_APPID", "48d04aae")
    api_secret = os.environ.get("SPARK_API_SECRET", "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz")
    api_key = os.environ.get("SPARK_API_KEY", "3ad87d03c4e3a4fb7d7b36a7dfa3be00")
    domain = "4.0Ultra"
    Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

    prompt = (
        "你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
        "生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
        "每个回答应该准确且不超过50字,同时不少于20字,以保证内容的简洁和有用性。\n"
        f"问题:{question}\n\n"
        "请生成一个详细回答。"
    )

    question_data = [{"role": "user", "content": prompt}]
    # SparkApi keeps its result in module-level state; clear it so a previous
    # call's text cannot leak into this answer.
    SparkApi.answer = ""
    SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question_data)
    return SparkApi.answer.strip()
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
# Load saved progress.
def load_checkpoint(path=None):
    """Return the progress value previously stored by ``save_checkpoint``.

    Args:
        path: checkpoint file to read; defaults to the module-level
            ``checkpoint_file``.

    Returns:
        The stored integer (never negative). Returns 0, i.e. start from the
        beginning, in two cases:

        - the file does not exist;
        - the file does not contain a valid integer, e.g. it is empty because
          an earlier write was interrupted. Previously this raised
          ``ValueError`` and blocked every later run.
    """
    if path is None:
        path = checkpoint_file
    try:
        with open(path, 'r', encoding='utf-8') as f:
            raw = f.read().strip()
    except FileNotFoundError:
        return 0  # no checkpoint yet: start from 0
    try:
        value = int(raw)
    except ValueError:
        print(f"Ignoring unreadable checkpoint {path!r}: {raw!r}")
        return 0
    return max(value, 0)
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
# Persist progress.
def save_checkpoint(index, path=None):
    """Store *index* in the checkpoint file, replacing any previous value.

    The value is written to ``<path>.tmp`` first and then moved into place
    with ``os.replace``. An interruption mid-write therefore cannot leave an
    empty or truncated checkpoint, which ``int()`` would fail to parse on the
    next run.

    Args:
        index: progress value to store (serialized with ``str``).
        path: checkpoint file; defaults to the module-level ``checkpoint_file``.
    """
    if path is None:
        path = checkpoint_file
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.write(str(index))
    os.replace(tmp_path, path)
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
# Recover the already-finished prefix of temp_file after an interruption.
def _load_finished_lines():
    """Return the output lines a previous interrupted run already wrote.

    The checkpoint is read as the number of leading input lines that were
    fully written to ``temp_file``. It is capped by the number of complete
    (newline-terminated) lines actually present there. A stale checkpoint,
    or one from an older run that stored "last index" instead of "count",
    can then at worst cause a line to be reprocessed, never skipped.
    """
    start = load_checkpoint()
    if start <= 0 or not os.path.exists(temp_file):
        return []
    finished = []
    with open(temp_file, 'r', encoding='utf-8') as tf:
        for line in tf:
            if len(finished) >= start or not line.endswith('\n'):
                break
            finished.append(line)
    return finished


if __name__ == '__main__':
    # Answers shorter than this many characters are regenerated via the API.
    min_answer_len = 11

    # Resume: keep what an interrupted run already wrote. Reopening temp_file
    # with 'w' alone would truncate it and silently discard those answers.
    finished_lines = _load_finished_lines()
    start_index = len(finished_lines)

    with open(input_file, 'r', encoding='utf-8') as f, open(temp_file, 'w', encoding='utf-8') as temp_f:
        temp_f.writelines(finished_lines)
        for i, line in enumerate(tqdm(f)):
            if i < start_index:
                continue  # already present in temp_file
            item = json.loads(line)

            if len(item['output']) < min_answer_len:
                new_answer = generate_answer_via_api(item['input'])
                # Keep the existing answer if the API produced nothing.
                if new_answer:
                    item['output'] = new_answer

            json.dump(item, temp_f, ensure_ascii=False)
            temp_f.write('\n')
            # Flush before recording progress so the checkpoint never counts
            # a line that has not reached temp_file yet.
            temp_f.flush()
            # Checkpoint = number of input lines fully written to temp_file.
            save_checkpoint(i + 1)

    # Replace the original file with the updated one.
    os.replace(temp_file, input_file)
    # The run completed; drop the checkpoint so the next run starts fresh.
    if os.path.exists(checkpoint_file):
        os.remove(checkpoint_file)
    print(f"已更新 {input_file} 文件,包含重新生成的回答。")