117 lines
4.2 KiB
Python
117 lines
4.2 KiB
Python
# -*- coding: utf-8 -*-
|
||
import json
|
||
import os
|
||
import re
|
||
from tqdm import tqdm
|
||
|
||
import SparkApi
|
||
|
||
# Source dataset path (JSONL: one {"input": ...} object per line).
input_file = 'output/train_expanded.jsonl'
# Destination path; generated QA pairs are appended here as JSONL.
output_file = 'output/train_expanded_2.jsonl'
# Resume-checkpoint path; holds a single integer item index.
checkpoint_file = 'output/e2_progress_checkpoint.txt'
|
||
|
||
|
||
# Call the Spark chat API to expand one piece of content into QA pairs.
def generate_qa_via_api(content, *,
                        appid=None,
                        api_secret=None,
                        api_key=None,
                        domain="4.0Ultra",
                        spark_url="wss://spark-api.xf-yun.com/v4.0/chat"):
    """Generate three QA pairs about *content* via the iFlytek Spark API.

    Args:
        content: Source text the questions and answers must be grounded in.
        appid, api_secret, api_key: Spark credentials. When omitted they are
            read from the SPARK_APPID / SPARK_API_SECRET / SPARK_API_KEY
            environment variables, falling back to the original literals so
            existing callers keep working.
        domain: Spark model domain identifier.
        spark_url: Websocket endpoint of the chat service.

    Returns:
        The raw model reply (expected "问题N:.../回答N:..." text), stripped.
    """
    # SECURITY: credentials were hard-coded in this function. Prefer the
    # environment variables below; the literal fallbacks only preserve the
    # original behavior and should be rotated and removed.
    appid = appid or os.getenv("SPARK_APPID", "48d04aae")
    api_secret = api_secret or os.getenv(
        "SPARK_API_SECRET", "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz")
    api_key = api_key or os.getenv(
        "SPARK_API_KEY", "3ad87d03c4e3a4fb7d7b36a7dfa3be00")

    prompt = (
        f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
        f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
        f"1. 根据给定内容生成**三个**相关的问题和回答。\n"
        f"2. 你可以简化问题、提取具体要素进行提问,或扩展内容生成额外的相关问题。\n"
        f"3. **问题必须简洁明了**,并涵盖内容中的关键信息。\n"
        f"4. 每个回答应该准确且**不超过50字**,同时**不少于20字**,以保证内容的简洁和有用性。\n"
        f"5. 仅围绕油橄榄栽培的相关内容生成问答对,忽略其他无关信息。\n\n"
        f"以下是给定内容:\n\n"
        f"内容:{content}\n\n"
        f"请按如下格式生成输出:\n"
        f"问题1:<生成第一个问题>\n"
        f"回答1:<生成第一个回答>\n"
        f"问题2:<生成第二个问题>\n"
        f"回答2:<生成第二个回答>\n"
        f"问题3:<生成第三个问题>\n"
        f"回答3:<生成第三个回答>\n\n"
        f"请确保每个问题和回答都保持与内容的紧密相关性,并保持专业性。"
    )

    question = [{"role": "user", "content": prompt}]
    # SparkApi accumulates the streamed reply into its module-level
    # ``answer`` attribute, so it must be reset before every call.
    SparkApi.answer = ""
    SparkApi.main(appid, api_key, api_secret, spark_url, domain, question)
    return SparkApi.answer.strip()
|
||
|
||
|
||
# Load the resume point saved by save_checkpoint().
def load_checkpoint(path=None):
    """Return the index of the next item to process, or 0 to start over.

    Args:
        path: Checkpoint file to read; defaults to the module-level
            ``checkpoint_file``.

    Returns:
        The stored integer index. Returns 0 when the file is missing,
        empty, or unparseable (a corrupted checkpoint restarts the run
        instead of crashing it; reprocessing is safe, losing the run is not).
    """
    path = checkpoint_file if path is None else path
    if os.path.exists(path):
        with open(path, 'r') as f:
            data = f.read().strip()
        if data:
            try:
                return int(data)
            except ValueError:
                pass  # corrupted checkpoint: fall through and start from 0
    return 0
|
||
|
||
|
||
# Persist the resume point read back by load_checkpoint().
def save_checkpoint(index, path=None):
    """Write *index* (the next item to process) to the checkpoint file.

    Args:
        index: Integer item index to persist.
        path: Checkpoint file to write; defaults to the module-level
            ``checkpoint_file``.
    """
    path = checkpoint_file if path is None else path
    with open(path, 'w') as f:
        f.write(str(index))
|
||
|
||
|
||
# Parse the model reply into structured QA pairs (handles several per reply).
def parse_multiple_qa(answer_text):
    """Extract QA pairs from a "问题N:.../回答N:..." formatted reply.

    Args:
        answer_text: Raw reply text produced by generate_qa_via_api().

    Returns:
        A list of ``{"input": question, "output": answer}`` dicts, one per
        pair found; an empty list when nothing matches.

    Bug fix: the previous pattern ended with a capturing ``(问题|$)`` group,
    which *consumed* the leading "问题" of the following pair; ``findall``
    then resumed after it, so every second pair was silently dropped. A
    zero-width lookahead keeps the scanner positioned at the next pair.
    """
    pattern = re.compile(r"问题\d+:(.*?)回答\d+:(.*?)(?=问题\d+:|$)", re.S)
    return [
        {"input": question.strip(), "output": answer.strip()}
        for question, answer in pattern.findall(answer_text)
    ]
|
||
|
||
|
||
if __name__ == '__main__':
    # Load the source dataset: one JSON object per line.
    with open(input_file, 'r', encoding='utf-8') as f:
        text_data = [json.loads(line) for line in f]

    # Resume from the last checkpoint, if any.
    start_index = load_checkpoint()

    # Append newly generated QA pairs from the checkpoint onward.
    with open(output_file, 'a', encoding='utf-8') as f:
        for i in tqdm(range(start_index, len(text_data))):
            item = text_data[i]
            input_content = item['input']

            try:
                # Expand this entry into new QA pairs via the Spark API.
                api_generated_qa = generate_qa_via_api(input_content)
                qa_pairs = parse_multiple_qa(api_generated_qa)

                # Write each generated pair as one JSONL record.
                for qa in qa_pairs:
                    json.dump(qa, f, ensure_ascii=False)
                    f.write('\n')
                # Flush before checkpointing so a crash cannot record an
                # item as done while its output is still buffered.
                f.flush()

                # BUG FIX: checkpoint the *next* index. Saving ``i`` made a
                # resumed run start at ``i`` and reprocess the item, which
                # duplicated its QA pairs in the append-mode output file.
                save_checkpoint(i + 1)
            except Exception as e:
                # Best-effort pipeline: log, skip this entry, keep going.
                print(f"Error processing item {i}: {e}")
                # ``i + 1`` so the resumed run also skips the failed item,
                # matching the "skip and continue" intent here.
                save_checkpoint(i + 1)
                continue

    print(f"已生成 {output_file} 文件,包含扩展的问答对。")
|