OliveSensorAPI/IOTLLM/generate_data/EC_process/extend_QA.py

117 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
import json
import os
import re
from tqdm import tqdm
import SparkApi
# Input file path (JSONL dataset to expand)
input_file = 'output/train_expanded.jsonl'
# Output file path (generated QA pairs are appended here as JSONL)
output_file = 'output/train_expanded_2.jsonl'
# Checkpoint file path (stores the resume index for restartable runs)
checkpoint_file = 'output/e2_progress_checkpoint.txt'
# Call the Spark LLM API to generate QA pairs from a piece of source content.
def generate_qa_via_api(content):
    """Generate three olive-cultivation QA pairs for *content* via the Spark API.

    Parameters
    ----------
    content : str
        Source text the model should derive questions and answers from.

    Returns
    -------
    str
        The raw model output, stripped of surrounding whitespace; it is
        parsed downstream by ``parse_multiple_qa``.
    """
    # SECURITY: credentials were hard-coded in source. Prefer environment
    # variables; fall back to the original values for backward compatibility.
    # Rotate these keys — they have been committed to version control.
    appid = os.environ.get("SPARK_APPID", "48d04aae")
    api_secret = os.environ.get("SPARK_API_SECRET", "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz")
    api_key = os.environ.get("SPARK_API_KEY", "3ad87d03c4e3a4fb7d7b36a7dfa3be00")
    domain = "4.0Ultra"
    Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"
    # Prompt is kept verbatim (runtime behavior): asks for exactly three
    # QA pairs in a fixed "问题N/回答N" layout that parse_multiple_qa expects.
    prompt = (
        f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
        f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
        f"1. 根据给定内容生成**三个**相关的问题和回答。\n"
        f"2. 你可以简化问题、提取具体要素进行提问,或扩展内容生成额外的相关问题。\n"
        f"3. **问题必须简洁明了**,并涵盖内容中的关键信息。\n"
        f"4. 每个回答应该准确且**不超过50字**,同时**不少于20字**,以保证内容的简洁和有用性。\n"
        f"5. 仅围绕油橄榄栽培的相关内容生成问答对,忽略其他无关信息。\n\n"
        f"以下是给定内容:\n\n"
        f"内容:{content}\n\n"
        f"请按如下格式生成输出:\n"
        f"问题1<生成第一个问题>\n"
        f"回答1<生成第一个回答>\n"
        f"问题2<生成第二个问题>\n"
        f"回答2<生成第二个回答>\n"
        f"问题3<生成第三个问题>\n"
        f"回答3<生成第三个回答>\n\n"
        f"请确保每个问题和回答都保持与内容的紧密相关性,并保持专业性。"
    )
    question = [{"role": "user", "content": prompt}]
    # SparkApi accumulates the streamed answer into a module-level variable;
    # reset it before each call so responses do not leak between items.
    SparkApi.answer = ""
    SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)
    return SparkApi.answer.strip()
# Load resume progress (index of the next dataset line to process).
def load_checkpoint(path=None):
    """Return the saved progress index, or 0 when no usable checkpoint exists.

    Parameters
    ----------
    path : str | None
        Checkpoint file to read. Defaults to the module-level
        ``checkpoint_file`` so existing callers are unaffected.

    Returns
    -------
    int
        The stored index, or 0 if the file is missing, empty, or corrupt.
    """
    if path is None:
        path = checkpoint_file
    if os.path.exists(path):
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return int(f.read().strip())
        except ValueError:
            # Empty or corrupted checkpoint: restart from the beginning
            # instead of crashing (the original raised ValueError here).
            return 0
    return 0  # no checkpoint yet: start from index 0
# Persist resume progress so an interrupted run can continue later.
def save_checkpoint(index, path=None):
    """Write *index* to the checkpoint file, overwriting any previous value.

    Parameters
    ----------
    index : int
        Progress index to persist.
    path : str | None
        Checkpoint file to write. Defaults to the module-level
        ``checkpoint_file`` so existing callers are unaffected.
    """
    if path is None:
        path = checkpoint_file
    with open(path, 'w', encoding='utf-8') as f:
        f.write(str(index))
# Parse the model's reply text into a list of {"input", "output"} QA dicts.
def parse_multiple_qa(answer_text):
    """Extract every "问题N ... 回答N ..." pair from *answer_text*.

    Parameters
    ----------
    answer_text : str
        Raw model output formatted as alternating 问题N/回答N sections.

    Returns
    -------
    list[dict]
        One ``{"input": question, "output": answer}`` dict per pair, in order.

    BUG FIX: the original pattern ended with a *consuming* group
    ``(问题|$)``, which swallowed the next "问题" header; with three QA
    pairs, ``findall`` resumed scanning past pair 2's header and silently
    dropped the middle pair. A zero-width lookahead keeps the next header
    available to the following match.
    """
    pattern = re.compile(r"问题\d+(.*?)回答\d+(.*?)(?=问题\d+|$)", re.S)
    qa_pairs = []
    for question, answer in pattern.findall(answer_text):
        qa_pairs.append({"input": question.strip(), "output": answer.strip()})
    return qa_pairs
if __name__ == '__main__':
    # Load the source dataset: one JSON object per line, each with an "input" field.
    with open(input_file, 'r', encoding='utf-8') as f:
        text_data = [json.loads(line) for line in f]
    # Resume from the last saved position (0 on a fresh run).
    start_index = load_checkpoint()
    # Append mode: earlier runs' output is preserved across restarts.
    with open(output_file, 'a', encoding='utf-8') as f:
        for i in tqdm(range(start_index, len(text_data))):
            item = text_data[i]
            input_content = item['input']
            try:
                # Ask the API for new QA pairs based on this item's content.
                api_generated_qa = generate_qa_via_api(input_content)
                # Parse the reply into structured QA dicts.
                qa_pairs = parse_multiple_qa(api_generated_qa)
                # Write each pair as one JSONL record.
                for qa in qa_pairs:
                    json.dump({"input": qa['input'], "output": qa['output']}, f, ensure_ascii=False)
                    f.write('\n')
                # Make the written records durable before recording progress,
                # so a crash cannot lose output the checkpoint claims exists.
                f.flush()
                # BUG FIX: record i + 1 (the next unprocessed index). The
                # original saved i, and since the loop resumes at the saved
                # value, the last processed item was re-processed after every
                # restart, duplicating its QA pairs in the append-mode output.
                save_checkpoint(i + 1)
            except Exception as e:
                # Best-effort pipeline: log the failure, skip this item on
                # future resumes as well, and keep going.
                print(f"Error processing item {i}: {e}")
                save_checkpoint(i + 1)
                continue
    print(f"已生成 {output_file} 文件,包含扩展的问答对。")