85 lines
2.8 KiB
Python
85 lines
2.8 KiB
Python
# -*- coding: utf-8 -*-
|
||
import json
|
||
import os
|
||
from tqdm import tqdm
|
||
import SparkApi
|
||
|
||
# 输入文件路径
|
||
input_file = 'output/train_expanded.jsonl'
|
||
# 断点文件路径
|
||
checkpoint_file = 'output/expand_checkpoint.txt'
|
||
# 临时文件路径
|
||
temp_file = 'output/tmp_train_expanded.jsonl'
|
||
|
||
|
||
# 调用API生成回答
|
||
def generate_answer_via_api(question):
|
||
appid = "48d04aae"
|
||
api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
|
||
api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
|
||
domain = "4.0Ultra"
|
||
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"
|
||
|
||
prompt = (
|
||
f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
|
||
f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
|
||
f"每个回答应该准确且不超过50字,同时不少于20字,以保证内容的简洁和有用性。\n"
|
||
f"问题:{question}\n\n"
|
||
f"请生成一个详细回答。"
|
||
)
|
||
|
||
question_data = [{"role": "user", "content": prompt}]
|
||
SparkApi.answer = ""
|
||
SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question_data)
|
||
return SparkApi.answer.strip()
|
||
|
||
|
||
# 加载断点进度
|
||
def load_checkpoint():
|
||
if os.path.exists(checkpoint_file):
|
||
with open(checkpoint_file, 'r') as f:
|
||
return int(f.read().strip()) # 返回已处理的行索引
|
||
return 0 # 没有断点则从0开始
|
||
|
||
|
||
# 保存断点进度
|
||
def save_checkpoint(index):
|
||
with open(checkpoint_file, 'w') as f:
|
||
f.write(str(index))
|
||
|
||
|
||
if __name__ == '__main__':
|
||
# 加载断点进度
|
||
start_index = load_checkpoint()
|
||
|
||
with open(input_file, 'r', encoding='utf-8') as f, open(temp_file, 'w', encoding='utf-8') as temp_f:
|
||
for i, line in enumerate(tqdm(f)):
|
||
item = json.loads(line)
|
||
|
||
# 从断点开始处理
|
||
if i >= start_index:
|
||
input_content = item['input']
|
||
output_content = item['output']
|
||
|
||
# # 检查是否是未提供回答的问答对
|
||
# if "未给" in output_content:
|
||
# # 使用API生成新的回答
|
||
# new_answer = generate_answer_via_api(input_content)
|
||
# item['output'] = new_answer
|
||
|
||
if len(output_content)<11:
|
||
# 使用API生成新的回答
|
||
new_answer = generate_answer_via_api(input_content)
|
||
item['output'] = new_answer
|
||
|
||
# 保存当前的进度索引
|
||
save_checkpoint(i)
|
||
|
||
# 写入更新内容到临时文件
|
||
json.dump(item, temp_f, ensure_ascii=False)
|
||
temp_f.write('\n')
|
||
|
||
# 替换原始文件
|
||
os.replace(temp_file, input_file)
|
||
print(f"已更新 {input_file} 文件,包含重新生成的回答。")
|