85 lines
2.8 KiB
Python
85 lines
2.8 KiB
Python
|
# -*- coding: utf-8 -*-
|
|||
|
import json
|
|||
|
import os
|
|||
|
from tqdm import tqdm
|
|||
|
import SparkApi
|
|||
|
|
|||
|
# 输入文件路径
|
|||
|
input_file = 'output/train_expanded.jsonl'
|
|||
|
# 断点文件路径
|
|||
|
checkpoint_file = 'output/expand_checkpoint.txt'
|
|||
|
# 临时文件路径
|
|||
|
temp_file = 'output/tmp_train_expanded.jsonl'
|
|||
|
|
|||
|
|
|||
|
# 调用API生成回答
|
|||
|
def generate_answer_via_api(question):
|
|||
|
appid = "48d04aae"
|
|||
|
api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
|
|||
|
api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
|
|||
|
domain = "4.0Ultra"
|
|||
|
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"
|
|||
|
|
|||
|
prompt = (
|
|||
|
f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
|
|||
|
f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
|
|||
|
f"每个回答应该准确且不超过50字,同时不少于20字,以保证内容的简洁和有用性。\n"
|
|||
|
f"问题:{question}\n\n"
|
|||
|
f"请生成一个详细回答。"
|
|||
|
)
|
|||
|
|
|||
|
question_data = [{"role": "user", "content": prompt}]
|
|||
|
SparkApi.answer = ""
|
|||
|
SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question_data)
|
|||
|
return SparkApi.answer.strip()
|
|||
|
|
|||
|
|
|||
|
# 加载断点进度
|
|||
|
def load_checkpoint():
|
|||
|
if os.path.exists(checkpoint_file):
|
|||
|
with open(checkpoint_file, 'r') as f:
|
|||
|
return int(f.read().strip()) # 返回已处理的行索引
|
|||
|
return 0 # 没有断点则从0开始
|
|||
|
|
|||
|
|
|||
|
# 保存断点进度
|
|||
|
def save_checkpoint(index):
|
|||
|
with open(checkpoint_file, 'w') as f:
|
|||
|
f.write(str(index))
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
# 加载断点进度
|
|||
|
start_index = load_checkpoint()
|
|||
|
|
|||
|
with open(input_file, 'r', encoding='utf-8') as f, open(temp_file, 'w', encoding='utf-8') as temp_f:
|
|||
|
for i, line in enumerate(tqdm(f)):
|
|||
|
item = json.loads(line)
|
|||
|
|
|||
|
# 从断点开始处理
|
|||
|
if i >= start_index:
|
|||
|
input_content = item['input']
|
|||
|
output_content = item['output']
|
|||
|
|
|||
|
# # 检查是否是未提供回答的问答对
|
|||
|
# if "未给" in output_content:
|
|||
|
# # 使用API生成新的回答
|
|||
|
# new_answer = generate_answer_via_api(input_content)
|
|||
|
# item['output'] = new_answer
|
|||
|
|
|||
|
if len(output_content)<11:
|
|||
|
# 使用API生成新的回答
|
|||
|
new_answer = generate_answer_via_api(input_content)
|
|||
|
item['output'] = new_answer
|
|||
|
|
|||
|
# 保存当前的进度索引
|
|||
|
save_checkpoint(i)
|
|||
|
|
|||
|
# 写入更新内容到临时文件
|
|||
|
json.dump(item, temp_f, ensure_ascii=False)
|
|||
|
temp_f.write('\n')
|
|||
|
|
|||
|
# 替换原始文件
|
|||
|
os.replace(temp_file, input_file)
|
|||
|
print(f"已更新 {input_file} 文件,包含重新生成的回答。")
|