# -*- coding: utf-8 -*-
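"""
Expand a JSONL fine-tuning dataset by asking the iFlytek Spark LLM to generate
additional question-answer pairs for each existing record.

For every line of `train_expanded.jsonl`, the script sends the record's `input`
text to the Spark 4.0 Ultra chat endpoint, parses the three QA pairs the model
is prompted to return, and appends them to `train_expanded_2.jsonl`. A simple
line-index checkpoint file lets an interrupted run be resumed.
"""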
import json
import os
import re

from tqdm import tqdm

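# SparkApi is the WebSocket demo client distributed by iFlytek for its Spark
# models; it exposes main() and a module-level `answer` string used below.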
import SparkApi

# Input dataset path
input_file = 'output/train_expanded.jsonl'
# Output dataset path
output_file = 'output/train_expanded_2.jsonl'
# Checkpoint (resume progress) file path
checkpoint_file = 'output/e2_progress_checkpoint.txt'


# Call the Spark API to generate QA pairs from one piece of source content
def generate_qa_via_api(content):
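    """Ask the Spark model for three QA pairs based on `content`.

    Returns the model's raw reply as a single string; splitting it into
    individual question/answer pairs is handled by parse_multiple_qa().
    """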
    appid = "48d04aae"
    api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
    api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
    domain = "4.0Ultra"
    Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

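    # The prompt is kept in Chinese to match the Chinese-language corpus and model.
    # It asks the model, acting as an olive-cultivation expert, to produce exactly
    # three concise QA pairs (answers between 20 and 50 characters) in the fixed
    # "问题N:/回答N:" layout that parse_multiple_qa() expects.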
    prompt = (
        "你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
        "生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
        "1. 根据给定内容生成**三个**相关的问题和回答。\n"
        "2. 你可以简化问题、提取具体要素进行提问,或扩展内容生成额外的相关问题。\n"
        "3. **问题必须简洁明了**,并涵盖内容中的关键信息。\n"
        "4. 每个回答应该准确且**不超过50字**,同时**不少于20字**,以保证内容的简洁和有用性。\n"
        "5. 仅围绕油橄榄栽培的相关内容生成问答对,忽略其他无关信息。\n\n"
        "以下是给定内容:\n\n"
        f"内容:{content}\n\n"
        "请按如下格式生成输出:\n"
        "问题1:<生成第一个问题>\n"
        "回答1:<生成第一个回答>\n"
        "问题2:<生成第二个问题>\n"
        "回答2:<生成第二个回答>\n"
        "问题3:<生成第三个问题>\n"
        "回答3:<生成第三个回答>\n\n"
        "请确保每个问题和回答都保持与内容的紧密相关性,并保持专业性。"
    )

    question = [{"role": "user", "content": prompt}]
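    # SparkApi.main() opens the WebSocket connection and accumulates the reply in
    # the module-level SparkApi.answer buffer, which is reset before each call.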
    SparkApi.answer = ""
    SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)
    return SparkApi.answer.strip()


# Load checkpoint progress
def load_checkpoint():
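    """Return the index of the next dataset item to process (0 if no checkpoint exists)."""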
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())  # index stored by save_checkpoint()
    return 0  # no checkpoint yet: start from the beginning


# Save checkpoint progress
def save_checkpoint(index):
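    """Persist `index`, the position of the next item to process, to the checkpoint file."""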
    with open(checkpoint_file, 'w') as f:
        f.write(str(index))


# Parse the returned text into QA pairs, handling multiple pairs per reply
def parse_multiple_qa(answer_text):
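    """Extract QA pairs from the model reply.

    The reply is expected to follow the layout requested in the prompt:

        问题1:<question>
        回答1:<answer>
        问题2:<question>
        回答2:<answer>
        ...

    Returns a list of dicts of the form {"input": question, "output": answer}.
    """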
    qa_pairs = []
    # Find every "问题N:…回答N:…" block with a regular expression. A lookahead is
    # used for the terminating "问题" so the next pair is not consumed and lost.
    pattern = re.compile(r"问题\d+:(.*?)回答\d+:(.*?)(?=问题\d+:|$)", re.S)
    matches = pattern.findall(answer_text)

    for match in matches:
        question = match[0].strip()
        answer = match[1].strip()
        qa_pairs.append({"input": question, "output": answer})

    return qa_pairs


if __name__ == '__main__':
    # Load the source dataset (one JSON object per line)
    with open(input_file, 'r', encoding='utf-8') as f:
        text_data = [json.loads(line) for line in f]

    # Load checkpoint progress
    start_index = load_checkpoint()

    # Resume generating QA pairs from the checkpoint
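    # The output file is opened in append mode so results written before an
    # interruption are preserved; each QA pair is stored as one JSON line.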
    with open(output_file, 'a', encoding='utf-8') as f:
        for i in tqdm(range(start_index, len(text_data))):
            item = text_data[i]
            input_content = item['input']

            try:
                # Generate new QA pairs via the API
                api_generated_qa = generate_qa_via_api(input_content)

                # Parse the API reply into QA pairs
                qa_pairs = parse_multiple_qa(api_generated_qa)

                # Save the generated QA pairs
                for qa in qa_pairs:
                    json.dump(qa, f, ensure_ascii=False)
                    f.write('\n')

                # Record progress: i + 1 so a resumed run starts at the next item
                save_checkpoint(i + 1)

            except Exception as e:
                print(f"Error processing item {i}: {e}")
                # Skip this item and continue; recording i + 1 also skips it on resume
                save_checkpoint(i + 1)
                continue

    print(f"Finished writing {output_file} with the expanded QA pairs.")