# -*- coding: utf-8 -*-
# @Time    : 2024/10/22
# @Author  : 黄子寒
# @File    : generate_qa_with_multiple_pairs.py
# @Project : EmoLLM

import os
import re
import json

from tqdm import tqdm

import SparkApi

# iFLYTEK Spark credentials (as provided in the original project)
appid = "f0f73de5"
api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk"
api_key = "5773f6f95563708de994d17b7ea5d414"

# Spark service endpoint and model version
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

# Holds the cleaned text segments
text_data = []

# Checkpoint file storing the index of the next paragraph to process
checkpoint_file = "output/progress_checkpoint.txt"

# Load the preprocessed text file
with open("../processPDF/cleaned_data.txt", "r", encoding="utf-8") as f:
    cleaned_text = f.read()


# Custom splitter: group whole sentences into chunks of at most max_length characters
def split_text_to_sentences(text, max_length=300):
    # Split after every full stop (。) while keeping the delimiter on the sentence
    sentences = re.split(r'(?<=。)', text)
    grouped_sentences = []
    current_group = ""

    for sentence in sentences:
        if len(current_group) + len(sentence) <= max_length:
            current_group += sentence
        else:
            grouped_sentences.append(current_group.strip())
            current_group = sentence

    if current_group:
        grouped_sentences.append(current_group.strip())

    return grouped_sentences
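
# A quick illustration of the chunking behaviour (hypothetical input, shown
# only as a sketch): whole sentences are packed into one chunk as long as
# their combined length stays within max_length.
#
#   split_text_to_sentences("句子一。句子二。句子三。", max_length=10)
#   -> ["句子一。句子二。", "句子三。"]   (each chunk is at most 10 characters)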


# Load checkpoint progress
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())  # index of the next paragraph to process
    return 0  # no checkpoint yet, start from the beginning


# Save checkpoint progress
def save_checkpoint(index):
    with open(checkpoint_file, 'w') as f:
        f.write(str(index))
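
# Round-trip sketch: save_checkpoint(42) writes "42" into
# output/progress_checkpoint.txt, and a later load_checkpoint() returns 42,
# so an interrupted run resumes at paragraph 42 instead of starting over.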


# Split the text into chunks of the required length
paragraphs = split_text_to_sentences(cleaned_text, 300)


# Build the detailed prompt that asks the LLM to generate input/output pairs;
# the model is allowed to produce several QA pairs per passage
def create_prompt(content):
    prompt = (
        f"你是一位油橄榄栽培专家。"
        f"根据以下内容生成一个或多个问题和回答对,请保证语句通顺有逻辑,同时忽略所有内容中和图示相关的内容:\n\n"
        f"内容:{content}\n\n"
        f"请以如下格式生成输出:\n"
        f"问题1:<在这里生成第一个问题>\n"
        f"回答1:<在这里生成第一个回答>\n"
        f"问题2:<在这里生成第二个问题(如有)>\n"
        f"回答2:<在这里生成第二个回答(如有)>\n"
        f"..."
    )
    return prompt
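
# For a passage such as "油橄榄喜光。" the prompt asks the model to answer in
# the numbered 问题N:/回答N: layout shown above, which is exactly the format
# that parse_multiple_qa() below expects to receive.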


# Parse the returned text, handling the case of multiple QA pairs
def parse_multiple_qa(answer_text):
    qa_pairs = []

    # Match every 问题N:/回答N: pair. [::] accepts both the full-width and
    # the ASCII colon, and the lookahead (?=问题|$) ends each answer at the
    # next question without consuming it, so consecutive pairs all match.
    pattern = re.compile(r"问题\d+[::](.*?)回答\d+[::](.*?)(?=问题|$)", re.S)
    matches = pattern.findall(answer_text)

    for match in matches:
        question = match[0].strip()
        answer = match[1].strip()
        qa_pairs.append({"input": question, "output": answer})

    return qa_pairs
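
# Sketch of the expected round trip (hypothetical model output):
#
#   parse_multiple_qa("问题1:油橄榄适合什么气候?回答1:地中海气候。问题2:何时采收?回答2:秋季。")
#   -> [{"input": "油橄榄适合什么气候?", "output": "地中海气候。"},
#       {"input": "何时采收?", "output": "秋季。"}]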


# Cap the total prompt length to keep each API request within quota
def checklen(text):
    # text is a list of chat messages; drop the oldest messages until the
    # combined content length falls under the limit
    while sum(len(msg["content"]) for msg in text) > 80000:
        del text[0]
    return text
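
# Hypothetical sketch of the trimming: with the 80000-character budget, a
# history of two long messages keeps only the most recent one that fits.
#
#   msgs = [{"role": "user", "content": "a" * 50000},
#           {"role": "user", "content": "b" * 50000}]
#   checklen(msgs) -> [{"role": "user", "content": "b" * 50000}]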


if __name__ == '__main__':
    text_data.clear()
    file_name = 'output/train_optimized_multiple.jsonl'
    conversations = []

    # Make sure the output directory exists before the checkpoint and JSONL writes
    os.makedirs("output", exist_ok=True)

    # Load the previous progress
    start_index = load_checkpoint()

    # Resume generating QA pairs from the checkpoint
    for i in tqdm(range(start_index, len(paragraphs))):  # process all remaining paragraphs
        content = paragraphs[i].strip()  # strip leading/trailing whitespace
        print("====================\ncontent:", content, "\n==================\n")
        if len(content) == 0:
            continue

        # Build the LLM prompt
        prompt = create_prompt(content)
        question = checklen([{"role": "user", "content": prompt}])

        # Call the LLM to generate the QA pairs
        SparkApi.answer = ""  # clear the previous answer
        SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)  # call the API for a response

        # The model's reply, to be split into questions and answers
        answer_text = SparkApi.answer.strip()

        # Parse the (possibly multiple) QA pairs
        qa_pairs = parse_multiple_qa(answer_text)

        for qa_pair in qa_pairs:
            conversation = {
                "input": qa_pair['input'],
                "output": qa_pair['output']
            }

            # Append the conversation record to the JSONL file
            with open(file_name, 'a', encoding='utf-8') as file:
                json.dump(conversation, file, ensure_ascii=False)
                file.write("\n")

        # After each paragraph, record the index of the next paragraph so a
        # restart resumes without re-processing the one just finished
        save_checkpoint(i + 1)

    print(f"已生成 {file_name} 文件,包含问答对。")