OliveSensorAPI/generate_data/EC_process/process_missing_QA.py

85 lines
2.8 KiB
Python
Raw Normal View History

2024-11-11 17:32:36 +08:00
# -*- coding: utf-8 -*-
import json
import os
from tqdm import tqdm
import SparkApi
# 输入文件路径
input_file = 'output/train_expanded.jsonl'
# 断点文件路径
checkpoint_file = 'output/expand_checkpoint.txt'
# 临时文件路径
temp_file = 'output/tmp_train_expanded.jsonl'
# 调用API生成回答
def generate_answer_via_api(question):
appid = "48d04aae"
api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"
prompt = (
f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
f"每个回答应该准确且不超过50字同时不少于20字以保证内容的简洁和有用性。\n"
f"问题:{question}\n\n"
f"请生成一个详细回答。"
)
question_data = [{"role": "user", "content": prompt}]
SparkApi.answer = ""
SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question_data)
return SparkApi.answer.strip()
# 加载断点进度
def load_checkpoint():
if os.path.exists(checkpoint_file):
with open(checkpoint_file, 'r') as f:
return int(f.read().strip()) # 返回已处理的行索引
return 0 # 没有断点则从0开始
# 保存断点进度
def save_checkpoint(index):
with open(checkpoint_file, 'w') as f:
f.write(str(index))
if __name__ == '__main__':
# 加载断点进度
start_index = load_checkpoint()
with open(input_file, 'r', encoding='utf-8') as f, open(temp_file, 'w', encoding='utf-8') as temp_f:
for i, line in enumerate(tqdm(f)):
item = json.loads(line)
# 从断点开始处理
if i >= start_index:
input_content = item['input']
output_content = item['output']
# # 检查是否是未提供回答的问答对
# if "未给" in output_content:
# # 使用API生成新的回答
# new_answer = generate_answer_via_api(input_content)
# item['output'] = new_answer
if len(output_content)<11:
# 使用API生成新的回答
new_answer = generate_answer_via_api(input_content)
item['output'] = new_answer
# 保存当前的进度索引
save_checkpoint(i)
# 写入更新内容到临时文件
json.dump(item, temp_f, ensure_ascii=False)
temp_f.write('\n')
# 替换原始文件
os.replace(temp_file, input_file)
print(f"已更新 {input_file} 文件,包含重新生成的回答。")