OliveSensorAPI/IOTLLM/generate_data/EC_process/gen_QA.py

154 lines
4.6 KiB
Python
Raw Normal View History

2024-11-11 17:32:36 +08:00
# -*- coding: utf-8 -*-
# @Time : 2024/10/22
# @Author : 黄子寒
# @File : generate_qa_with_multiple_pairs.py
# @Project : EmoLLM
import os
import re
from tqdm import tqdm
import SparkApi
import json
appid = "f0f73de5"
api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk"
api_key = "5773f6f95563708de994d17b7ea5d414"
# Spark服务地址及版本
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"
# 准备存储清洗后的文本
text_data = []
# 断点文件,用于存储上次处理的段落索引
checkpoint_file = "output/progress_checkpoint.txt"
# 加载处理好的文本文件
with open("../processPDF/cleaned_data.txt", "r", encoding="utf-8") as f:
cleaned_text = f.read()
2024-12-10 23:37:45 +08:00
# 自定义分割函数按最大300字以内的句子段落
2024-11-11 17:32:36 +08:00
def split_text_to_sentences(text, max_length=300):
sentences = re.split('(?<=。)', text)
grouped_sentences = []
current_group = ""
for sentence in sentences:
if len(current_group) + len(sentence) <= max_length:
current_group += sentence
else:
grouped_sentences.append(current_group.strip())
current_group = sentence
if current_group:
grouped_sentences.append(current_group.strip())
return grouped_sentences
# 加载断点进度
def load_checkpoint():
if os.path.exists(checkpoint_file):
with open(checkpoint_file, 'r') as f:
return int(f.read().strip()) # 返回已处理的段落索引
return 0 # 没有断点则从0开始
# 保存断点进度
def save_checkpoint(index):
with open(checkpoint_file, 'w') as f:
f.write(str(index))
# 将文本按要求的长度进行分割
paragraphs = split_text_to_sentences(cleaned_text, 300)
# 构建 LLM 生成 input 和 output 的详细 prompt允许模型生成多个问答对
def create_prompt(content):
prompt = (
f"你是一位油橄榄栽培专家。"
f"根据以下内容生成一个或多个问题和回答对,请保证语句通顺有逻辑,同时忽略所有内容中和图示相关的内容:\n\n"
f"内容:{content}\n\n"
f"请以如下格式生成输出:\n"
f"问题1<在这里生成第一个问题>\n"
f"回答1<在这里生成第一个回答>\n"
f"问题2<在这里生成第二个问题(如有)>\n"
f"回答2<在这里生成第二个回答(如有)>\n"
f"..."
)
return prompt
# 解析返回的问答对,处理多个问答对的情况
def parse_multiple_qa(answer_text):
qa_pairs = []
# 通过正则表达式找到所有的问答对
pattern = re.compile(r"问题\d+(.*?)回答\d+(.*?)(问题|$)", re.S)
matches = pattern.findall(answer_text)
for match in matches:
question = match[0].strip()
answer = match[1].strip()
qa_pairs.append({"input": question, "output": answer})
return qa_pairs
# 迭代限制防止API额度过大
def checklen(text):
2024-12-10 23:37:45 +08:00
while len(text) > 80000:
2024-11-11 17:32:36 +08:00
del text[0]
return text
if __name__ == '__main__':
text_data.clear()
file_name = 'output/train_optimized_multiple.jsonl'
conversations = []
# 加载上次的进度
start_index = load_checkpoint()
# 从断点开始继续生成问答对
# 从断点开始继续生成问答对
for i in tqdm(range(start_index, len(paragraphs))): # 处理所有剩余的段落
content = paragraphs[i].strip() # 去除段落前后的空格
print("====================\ncontent:", content, "\n==================\n")
if len(content) == 0:
continue
# 构建 LLM 的 prompt
prompt = create_prompt(content)
question = checklen([{"role": "user", "content": prompt}])
# 调用 LLM 生成问答对
SparkApi.answer = "" # 清空之前的回答
SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question) # 调用API获取回答
# 将生成的文本分割为问题和回答
answer_text = SparkApi.answer.strip()
# 解析多个问答对
qa_pairs = parse_multiple_qa(answer_text)
for qa_pair in qa_pairs:
conversation = {
"input": qa_pair['input'],
"output": qa_pair['output']
}
# 将对话数据添加到文件中
with open(file_name, 'a', encoding='utf-8') as file:
json.dump(conversation, file, ensure_ascii=False)
file.write("\n")
# 每处理完一个段落,保存当前的进度索引
save_checkpoint(i)
print(f"已生成 {file_name} 文件,包含问答对。")