OliveSensorAPI/IOTLLM/generate_data/EC_process/process_missing_QA.py

85 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
import json
import os
from tqdm import tqdm
import SparkApi
# All paths are relative to the current working directory and share one
# output directory.
_OUTPUT_DIR = 'output'
# JSONL dataset being repaired; it is replaced by temp_file once processing ends.
input_file = f'{_OUTPUT_DIR}/train_expanded.jsonl'
# Text file holding the saved progress index, used to resume an interrupted run.
checkpoint_file = f'{_OUTPUT_DIR}/expand_checkpoint.txt'
# Scratch file receiving the rewritten records before they replace input_file.
temp_file = f'{_OUTPUT_DIR}/tmp_train_expanded.jsonl'
# Call the Spark chat API to generate an answer.
def generate_answer_via_api(question):
    """Generate a short expert answer for *question* via the Spark chat API.

    The (Chinese) prompt asks the model to act as an olive-cultivation expert
    and answer in no fewer than 20 and no more than 50 characters.

    Credentials are taken from the ``SPARK_APPID``, ``SPARK_API_SECRET`` and
    ``SPARK_API_KEY`` environment variables; when a variable is unset, the
    value that used to be hard-coded here is used instead.

    Args:
        question: question text inserted into the prompt.

    Returns:
        The text left in ``SparkApi.answer`` after ``SparkApi.main`` returns,
        with surrounding whitespace stripped.
    """
    # NOTE(review): the fallback credentials below were committed in plain text.
    # Rotate them and remove the defaults once every environment sets SPARK_*.
    appid = os.environ.get("SPARK_APPID", "48d04aae")
    api_secret = os.environ.get("SPARK_API_SECRET", "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz")
    api_key = os.environ.get("SPARK_API_KEY", "3ad87d03c4e3a4fb7d7b36a7dfa3be00")
    domain = "4.0Ultra"
    Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"
    prompt = (
        "你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
        "生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
        "每个回答应该准确且不超过50字同时不少于20字以保证内容的简洁和有用性。\n"
        f"问题:{question}\n\n"
        "请生成一个详细回答。"
    )
    question_data = [{"role": "user", "content": prompt}]
    # SparkApi.main reports its result through the module-level SparkApi.answer
    # (presumably accumulating the streamed reply -- confirm in SparkApi), so it
    # must be reset before every call.
    SparkApi.answer = ""
    SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question_data)
    return SparkApi.answer.strip()
# Load the saved progress.
def load_checkpoint(path=None):
    """Return the index stored in the checkpoint file, or 0 if there is none.

    A missing checkpoint means "start from the beginning". The same applies to
    an empty or unparsable one, for example a file left empty by an
    interrupted write. In that case a warning is printed instead of crashing.

    Args:
        path: checkpoint file to read; defaults to the module-level
            ``checkpoint_file``.

    Returns:
        The stored index as an int, or 0.
    """
    if path is None:
        path = checkpoint_file
    try:
        with open(path, 'r', encoding='utf-8') as f:
            content = f.read().strip()
    except FileNotFoundError:
        return 0  # no checkpoint: start from line 0
    try:
        return int(content)
    except ValueError:
        print(f"断点文件 {path} 内容无效,将从头开始处理。")
        return 0
# Save the progress.
def save_checkpoint(index, path=None):
    """Persist *index* (written as decimal text) as the resume position.

    The value is first written to a sibling ``.tmp`` file, which is then moved
    over the checkpoint with ``os.replace``. A crash mid-write therefore leaves
    either the old or the new checkpoint, never an empty one; on POSIX the
    rename is atomic.

    Args:
        index: progress index to record.
        path: checkpoint file to write; defaults to the module-level
            ``checkpoint_file``.
    """
    if path is None:
        path = checkpoint_file
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.write(str(index))
    os.replace(tmp_path, path)
# Answers shorter than this many characters are treated as missing and are
# regenerated through the API.
MIN_ANSWER_LEN = 11


def read_processed_lines(path, limit):
    """Return at most *limit* complete lines already written to *path*.

    Only newline-terminated lines count, so a record cut off by an interrupted
    write is dropped and gets regenerated. Returns an empty list when *path*
    does not exist or *limit* is not positive.
    """
    if limit <= 0 or not os.path.exists(path):
        return []
    lines = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if len(lines) >= limit or not line.endswith('\n'):
                break
            lines.append(line)
    return lines


if __name__ == '__main__':
    # Until the final os.replace, the results for already-processed lines exist
    # only in temp_file. Recover them before reopening it for writing.
    # Otherwise a resumed run would copy those lines from input_file unchanged
    # and silently lose every answer generated before the interruption.
    recovered = read_processed_lines(temp_file, load_checkpoint())
    # The checkpoint may be ahead of what actually reached temp_file, so resume
    # from the number of lines really recovered.
    start_index = len(recovered)
    with open(input_file, 'r', encoding='utf-8') as f, \
            open(temp_file, 'w', encoding='utf-8') as temp_f:
        temp_f.writelines(recovered)
        for i, line in enumerate(tqdm(f)):
            if i < start_index:
                continue  # already written to temp_file above
            item = json.loads(line)
            # An earlier (commented-out) variant regenerated only answers that
            # contained "未给" ("not given"); now any too-short answer is redone.
            if len(item['output']) < MIN_ANSWER_LEN:
                item['output'] = generate_answer_via_api(item['input'])
            temp_f.write(json.dumps(item, ensure_ascii=False) + '\n')
            # Flush before recording progress so the checkpoint never claims a
            # line that has not been handed to the OS yet.
            temp_f.flush()
            # Record the index of the next line to process.
            save_checkpoint(i + 1)
    # Replace the original file with the rewritten one.
    os.replace(temp_file, input_file)
    # The run completed. Drop the checkpoint so that the next run starts from
    # the top instead of skipping almost every line.
    if os.path.exists(checkpoint_file):
        os.remove(checkpoint_file)
    print(f"已更新 {input_file} 文件,包含重新生成的回答。")