OliveSensorAPI/IOTLLM/generate_data/EC_process/extend_QA.py

117 lines
4.2 KiB
Python
Raw Normal View History

2024-11-11 17:32:36 +08:00
# -*- coding: utf-8 -*-
import json
import os
import re
from tqdm import tqdm
import SparkApi
# 输入文件路径
input_file = 'output/train_expanded.jsonl'
# 输出文件路径
output_file = 'output/train_expanded_2.jsonl'
# 断点文件路径
checkpoint_file = 'output/e2_progress_checkpoint.txt'
# 调用API生成问答对
def generate_qa_via_api(content):
appid = "48d04aae"
api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"
prompt = (
f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
f"1. 根据给定内容生成**三个**相关的问题和回答。\n"
f"2. 你可以简化问题、提取具体要素进行提问,或扩展内容生成额外的相关问题。\n"
f"3. **问题必须简洁明了**,并涵盖内容中的关键信息。\n"
f"4. 每个回答应该准确且**不超过50字**,同时**不少于20字**,以保证内容的简洁和有用性。\n"
f"5. 仅围绕油橄榄栽培的相关内容生成问答对,忽略其他无关信息。\n\n"
f"以下是给定内容:\n\n"
f"内容:{content}\n\n"
f"请按如下格式生成输出:\n"
f"问题1<生成第一个问题>\n"
f"回答1<生成第一个回答>\n"
f"问题2<生成第二个问题>\n"
f"回答2<生成第二个回答>\n"
f"问题3<生成第三个问题>\n"
f"回答3<生成第三个回答>\n\n"
f"请确保每个问题和回答都保持与内容的紧密相关性,并保持专业性。"
)
question = [{"role": "user", "content": prompt}]
SparkApi.answer = ""
SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)
return SparkApi.answer.strip()
# 加载断点进度
def load_checkpoint():
if os.path.exists(checkpoint_file):
with open(checkpoint_file, 'r') as f:
return int(f.read().strip()) # 返回已处理的行索引
return 0 # 没有断点则从0开始
# 保存断点进度
def save_checkpoint(index):
with open(checkpoint_file, 'w') as f:
f.write(str(index))
# 解析返回的问答对,处理多个问答对的情况
def parse_multiple_qa(answer_text):
qa_pairs = []
# 通过正则表达式找到所有的问答对
pattern = re.compile(r"问题\d+(.*?)回答\d+(.*?)(问题|$)", re.S)
matches = pattern.findall(answer_text)
for match in matches:
question = match[0].strip()
answer = match[1].strip()
qa_pairs.append({"input": question, "output": answer})
return qa_pairs
if __name__ == '__main__':
# 加载原始数据集
with open(input_file, 'r', encoding='utf-8') as f:
text_data = [json.loads(line) for line in f]
# 加载断点进度
start_index = load_checkpoint()
# 从断点开始继续生成问答对
with open(output_file, 'a', encoding='utf-8') as f:
for i in tqdm(range(start_index, len(text_data))):
item = text_data[i]
input_content = item['input']
try:
# 使用API生成新的问答对
api_generated_qa = generate_qa_via_api(input_content)
# 解析API生成的问答对并添加到数据集
qa_pairs = parse_multiple_qa(api_generated_qa)
expanded_data = [{"input": qa_pair['input'], "output": qa_pair['output']} for qa_pair in qa_pairs]
# 保存生成的问答对
for qa in expanded_data:
json.dump(qa, f, ensure_ascii=False)
f.write('\n')
# 保存当前的进度索引
save_checkpoint(i)
except Exception as e:
print(f"Error processing item {i}: {e}")
# 跳过当前条目继续处理
save_checkpoint(i)
continue
print(f"已生成 {output_file} 文件,包含扩展的问答对。")