OliveSensorAPI/IOTLLM/generate_data/EC_process/gen_QA.py

154 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
# @Time : 2024/10/22
# @Author : 黄子寒
# @File : generate_qa_with_multiple_pairs.py
# @Project : EmoLLM
import os
import re
from tqdm import tqdm
import SparkApi
import json

# NOTE(review): API credentials are hard-coded and committed to the repo;
# move them to environment variables / a config file and rotate the keys.
appid = "f0f73de5"
api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk"
api_key = "5773f6f95563708de994d17b7ea5d414"

# iFLYTEK Spark service endpoint and model version.
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

# Holds cleaned text fragments (cleared again at the start of __main__).
text_data = []

# Checkpoint file storing the paragraph index reached by the previous run.
checkpoint_file = "output/progress_checkpoint.txt"

# Load the pre-cleaned corpus produced by the PDF-processing step.
with open("../processPDF/cleaned_data.txt", "r", encoding="utf-8") as f:
    cleaned_text = f.read()
# Split text into sentence groups of at most `max_length` characters each.
def split_text_to_sentences(text, max_length=300):
    """Split *text* on Chinese full stops and pack sentences into groups.

    Consecutive sentences are accumulated while the running length stays
    within *max_length*.  A single sentence longer than *max_length*
    becomes a group of its own (it is never truncated).

    :param text: input prose; sentences are assumed to end with 。
    :param max_length: soft upper bound on each group's length
    :return: list of non-empty, stripped sentence groups
    """
    sentences = re.split('(?<=。)', text)
    grouped_sentences = []
    current_group = ""
    for sentence in sentences:
        if len(current_group) + len(sentence) <= max_length:
            current_group += sentence
        else:
            # BUGFIX: do not emit an empty group when the very first
            # sentence already exceeds max_length.
            if current_group.strip():
                grouped_sentences.append(current_group.strip())
            current_group = sentence
    if current_group.strip():
        grouped_sentences.append(current_group.strip())
    return grouped_sentences
# Load the paragraph index to resume from.
def load_checkpoint(path=None):
    """Return the paragraph index saved by a previous run.

    :param path: checkpoint file to read; defaults to the module-level
        ``checkpoint_file``.
    :return: the stored index, or 0 when no checkpoint exists yet
    """
    if path is None:
        path = checkpoint_file
    if os.path.exists(path):
        with open(path, 'r') as f:
            return int(f.read().strip())
    return 0
# Persist the resume point for the next run.
def save_checkpoint(index, path=None):
    """Write *index* to the checkpoint file, overwriting any old value.

    :param index: paragraph index to store
    :param path: checkpoint file to write; defaults to the module-level
        ``checkpoint_file``.
    """
    if path is None:
        path = checkpoint_file
    with open(path, 'w') as f:
        f.write(str(index))
# Split the corpus into paragraphs of at most 300 characters each.
paragraphs = split_text_to_sentences(cleaned_text, 300)
# Build the LLM prompt that asks for one or more QA pairs about `content`.
def create_prompt(content):
    """Return a prompt instructing the model to emit QA pairs.

    :param content: source paragraph the questions must be grounded in
    :return: the complete prompt string
    """
    parts = [
        "你是一位油橄榄栽培专家。",
        "根据以下内容生成一个或多个问题和回答对,请保证语句通顺有逻辑,同时忽略所有内容中和图示相关的内容:\n\n",
        f"内容:{content}\n\n",
        "请以如下格式生成输出:\n",
        "问题1<在这里生成第一个问题>\n",
        "回答1<在这里生成第一个回答>\n",
        "问题2<在这里生成第二个问题(如有)>\n",
        "回答2<在这里生成第二个回答(如有)>\n",
        "...",
    ]
    return "".join(parts)
# Parse the model's reply into a list of {"input", "output"} QA dicts.
def parse_multiple_qa(answer_text):
    """Extract every "问题N ... 回答N ..." pair from *answer_text*.

    :param answer_text: raw model output following the prompt's format
    :return: list of dicts with "input" (question) and "output" (answer)
    """
    qa_pairs = []
    # BUGFIX: the old pattern's trailing group (问题|$) CONSUMED the next
    # "问题" marker, so findall resumed past it and every second pair was
    # silently dropped.  A zero-width lookahead leaves the marker in place.
    pattern = re.compile(r"问题\d+(.*?)回答\d+(.*?)(?=问题\d+|$)", re.S)
    for question, answer in pattern.findall(answer_text):
        qa_pairs.append({"input": question.strip(), "output": answer.strip()})
    return qa_pairs
# Trim oldest messages so total request size stays within the API limit.
def checklen(text):
    """Drop messages from the front of *text* until the combined length of
    all "content" fields is at most 80000 characters.

    :param text: list of chat message dicts (mutated in place)
    :return: the same, possibly shortened, list
    """
    # BUGFIX: the old code compared len(text) -- the NUMBER of messages --
    # against 80000 instead of the number of characters, so the guard
    # effectively never fired.
    while text and sum(len(msg["content"]) for msg in text) > 80000:
        del text[0]
    return text
if __name__ == '__main__':
    text_data.clear()
    file_name = 'output/train_optimized_multiple.jsonl'

    # Resume from where the previous run stopped.
    start_index = load_checkpoint()

    # Generate QA pairs for every remaining paragraph, one API call each.
    for i in tqdm(range(start_index, len(paragraphs))):
        content = paragraphs[i].strip()
        print("====================\ncontent:", content, "\n==================\n")
        if len(content) == 0:
            continue

        # Build the prompt and cap the request size before sending.
        prompt = create_prompt(content)
        question = checklen([{"role": "user", "content": prompt}])

        # Call the LLM; SparkApi accumulates the reply in SparkApi.answer.
        SparkApi.answer = ""  # reset any previous reply
        SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)
        answer_text = SparkApi.answer.strip()

        # Parse the reply into QA pairs and append each one to the JSONL
        # file immediately, so completed work survives a crash.
        qa_pairs = parse_multiple_qa(answer_text)
        for qa_pair in qa_pairs:
            conversation = {
                "input": qa_pair['input'],
                "output": qa_pair['output']
            }
            with open(file_name, 'a', encoding='utf-8') as file:
                json.dump(conversation, file, ensure_ascii=False)
                file.write("\n")

        # BUGFIX: store i + 1 (not i) so a resumed run does not re-process
        # -- and duplicate in the JSONL output -- the paragraph that was
        # just completed.
        save_checkpoint(i + 1)

    print(f"已生成 {file_name} 文件,包含问答对。")