# -*- coding: utf-8 -*-
"""Regenerate too-short answers in a JSONL Q&A file via the Spark API.

Each line of ``input_file`` is parsed as a JSON object with ``input`` (the
question) and ``output`` (the answer) keys. Answers shorter than
``MIN_OUTPUT_LEN`` characters are replaced by an answer generated through
``SparkApi``. Processed lines are streamed to ``temp_file`` and progress is
recorded in ``checkpoint_file`` so an interrupted run can resume without
losing already-generated answers. When every line has been processed,
``temp_file`` replaces ``input_file`` and the checkpoint is removed.
"""
import json
import os

from tqdm import tqdm

import SparkApi

# Input file path (rewritten in place once the run completes).
input_file = 'output/train_expanded.jsonl'
# Checkpoint file: holds the number of input lines already written to temp_file.
checkpoint_file = 'output/expand_checkpoint.txt'
# Temporary output file that accumulates processed lines.
temp_file = 'output/tmp_train_expanded.jsonl'

# Answers with fewer characters than this are regenerated.
# (An earlier, now-disabled criterion regenerated answers containing "未给".)
MIN_OUTPUT_LEN = 11


def generate_answer_via_api(question):
    """Ask the Spark API for a new answer to *question*.

    Args:
        question: The question text inserted into the prompt.

    Returns:
        ``SparkApi.answer`` with surrounding whitespace stripped after
        ``SparkApi.main`` returns; an empty string if nothing was filled in.
    """
    # NOTE(review): these credentials were committed to source and should be
    # rotated and supplied only via the environment. The literals are kept as
    # fallbacks so existing setups keep working.
    appid = os.environ.get('SPARK_APPID', "48d04aae")
    api_secret = os.environ.get('SPARK_API_SECRET', "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz")
    api_key = os.environ.get('SPARK_API_KEY', "3ad87d03c4e3a4fb7d7b36a7dfa3be00")
    domain = "4.0Ultra"
    Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"
    prompt = (
        f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
        f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
        f"每个回答应该准确且不超过50字,同时不少于20字,以保证内容的简洁和有用性。\n"
        f"问题:{question}\n\n"
        f"请生成一个详细回答。"
    )
    question_data = [{"role": "user", "content": prompt}]
    # SparkApi.answer is a module-level global; reset it so text from a
    # previous call cannot leak into this one. Presumably SparkApi.main
    # accumulates the reply into it — confirm against SparkApi.
    SparkApi.answer = ""
    SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question_data)
    return SparkApi.answer.strip()


def load_checkpoint():
    """Return how many input lines a previous run already processed.

    Returns 0 when there is no checkpoint file, or when its contents are empty
    or not an integer (e.g. a crash while it was being written).
    """
    if not os.path.exists(checkpoint_file):
        return 0
    with open(checkpoint_file, 'r', encoding='utf-8') as f:
        content = f.read().strip()
    try:
        return max(int(content), 0)
    except ValueError:
        return 0


def save_checkpoint(index):
    """Persist *index* (the number of processed lines) to the checkpoint file.

    Writes to a sibling temp file and renames it over the checkpoint, so a
    crash mid-write cannot leave a truncated checkpoint behind.
    """
    tmp_path = checkpoint_file + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.write(str(index))
    os.replace(tmp_path, checkpoint_file)


def _restore_temp_prefix(count):
    """Trim temp_file to its first *count* complete lines; return how many were kept.

    Creates or truncates temp_file when there is nothing to keep. A final line
    without a trailing newline (a write interrupted by a crash) is discarded,
    so the returned value is always a safe index to resume from.
    """
    kept = []
    if count > 0 and os.path.exists(temp_file):
        with open(temp_file, 'r', encoding='utf-8') as f:
            for line in f:
                if len(kept) >= count or not line.endswith('\n'):
                    break
                kept.append(line)
    with open(temp_file, 'w', encoding='utf-8') as f:
        f.writelines(kept)
    return len(kept)


def main():
    """Process input_file line by line, resuming from the checkpoint if present."""
    # temp_file, not the checkpoint alone, is the source of truth for what
    # has been written, so resume from however many complete lines survive.
    start_index = _restore_temp_prefix(load_checkpoint())

    # Append mode keeps the answers generated before an interruption.
    with open(input_file, 'r', encoding='utf-8') as f, \
            open(temp_file, 'a', encoding='utf-8') as temp_f:
        for i, line in enumerate(tqdm(f)):
            if i < start_index:
                continue  # already present in temp_file from a previous run
            item = json.loads(line)
            if len(item['output']) < MIN_OUTPUT_LEN:
                new_answer = generate_answer_via_api(item['input'])
                # Keep the existing answer rather than overwrite it with an
                # empty reply.
                if new_answer:
                    item['output'] = new_answer
            temp_f.write(json.dumps(item, ensure_ascii=False) + '\n')
            # Flush before recording progress so the checkpoint never runs
            # ahead of the data that is actually on disk.
            temp_f.flush()
            save_checkpoint(i + 1)

    # Replace the original file, then clear the checkpoint so a later run
    # starts from the beginning instead of skipping everything.
    os.replace(temp_file, input_file)
    if os.path.exists(checkpoint_file):
        os.remove(checkpoint_file)
    print(f"已更新 {input_file} 文件,包含重新生成的回答。")


if __name__ == '__main__':
    main()