# -*- coding: utf-8 -*-
# @Time : 2024/10/22
# @Author : 黄子寒
# @File : generate_qa_with_multiple_pairs.py
# @Project : EmoLLM
import json
import os
import re

from tqdm import tqdm

import SparkApi
# NOTE(security): API credentials were committed to source control. They can
# now be overridden via environment variables; the literal values remain as
# fallbacks for backward compatibility, but the keys should be rotated and the
# env vars used instead.
appid = os.environ.get("SPARK_APPID", "f0f73de5")
api_secret = os.environ.get("SPARK_API_SECRET", "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk")
api_key = os.environ.get("SPARK_API_KEY", "5773f6f95563708de994d17b7ea5d414")

# Spark service endpoint and model version.
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

# Holds cleaned text fragments (cleared again at the start of __main__).
text_data = []

# Checkpoint file storing the paragraph index to resume from.
checkpoint_file = "output/progress_checkpoint.txt"

# Load the pre-cleaned corpus produced by the PDF-processing step.
with open("../processPDF/cleaned_data.txt", "r", encoding="utf-8") as f:
    cleaned_text = f.read()
# 自定义分割函数,按最大300字以内的句子段落
|
||
def split_text_to_sentences(text, max_length=300):
|
||
sentences = re.split('(?<=。)', text)
|
||
grouped_sentences = []
|
||
current_group = ""
|
||
|
||
for sentence in sentences:
|
||
if len(current_group) + len(sentence) <= max_length:
|
||
current_group += sentence
|
||
else:
|
||
grouped_sentences.append(current_group.strip())
|
||
current_group = sentence
|
||
|
||
if current_group:
|
||
grouped_sentences.append(current_group.strip())
|
||
|
||
return grouped_sentences
# Load the resume position recorded by a previous run.
def load_checkpoint():
    """Return the paragraph index to resume from, or 0 if there is no
    (readable) checkpoint.

    ROBUSTNESS FIX: an empty or corrupt checkpoint file used to raise
    ValueError and abort the whole run; now it falls back to 0.
    """
    if os.path.exists(checkpoint_file):
        try:
            with open(checkpoint_file, 'r') as f:
                return int(f.read().strip())  # index of last processed paragraph
        except (ValueError, OSError):
            # Unreadable/garbled checkpoint: restart from the beginning.
            return 0
    return 0  # no checkpoint yet
|
||
|
||
|
||
# 保存断点进度
|
||
def save_checkpoint(index):
|
||
with open(checkpoint_file, 'w') as f:
|
||
f.write(str(index))
|
||
|
||
|
||
# 将文本按要求的长度进行分割
|
||
paragraphs = split_text_to_sentences(cleaned_text, 300)
|
||
|
||
|
||
# 构建 LLM 生成 input 和 output 的详细 prompt,允许模型生成多个问答对
|
||
def create_prompt(content):
|
||
prompt = (
|
||
f"你是一位油橄榄栽培专家。"
|
||
f"根据以下内容生成一个或多个问题和回答对,请保证语句通顺有逻辑,同时忽略所有内容中和图示相关的内容:\n\n"
|
||
f"内容:{content}\n\n"
|
||
f"请以如下格式生成输出:\n"
|
||
f"问题1:<在这里生成第一个问题>\n"
|
||
f"回答1:<在这里生成第一个回答>\n"
|
||
f"问题2:<在这里生成第二个问题(如有)>\n"
|
||
f"回答2:<在这里生成第二个回答(如有)>\n"
|
||
f"..."
|
||
)
|
||
return prompt
# Parse the model reply into one or more QA pairs.
def parse_multiple_qa(answer_text):
    """Extract all "问题n:…回答n:…" pairs from *answer_text*.

    Returns a list of ``{"input": question, "output": answer}`` dicts
    (empty list when nothing matches).
    """
    qa_pairs = []
    # BUG FIX: the old pattern ended each match with the capturing group
    # (问题|$), which *consumed* the literal 问题 of the next pair, so
    # findall could never match the second of two consecutive pairs. Using a
    # lookahead leaves the next "问题n:" in place for the following match.
    pattern = re.compile(r"问题\d+:(.*?)回答\d+:(.*?)(?=问题\d+:|$)", re.S)
    matches = pattern.findall(answer_text)

    for question, answer in matches:
        qa_pairs.append({"input": question.strip(), "output": answer.strip()})

    return qa_pairs
|
||
|
||
|
||
# 迭代限制,防止API额度过大
|
||
def checklen(text):
|
||
while len(text) > 80000:
|
||
del text[0]
|
||
return text
|
||
|
||
|
||
if __name__ == '__main__':
|
||
text_data.clear()
|
||
file_name = 'output/train_optimized_multiple.jsonl'
|
||
conversations = []
|
||
|
||
# 加载上次的进度
|
||
start_index = load_checkpoint()
|
||
|
||
# 从断点开始继续生成问答对
|
||
# 从断点开始继续生成问答对
|
||
for i in tqdm(range(start_index, len(paragraphs))): # 处理所有剩余的段落
|
||
content = paragraphs[i].strip() # 去除段落前后的空格
|
||
print("====================\ncontent:", content, "\n==================\n")
|
||
if len(content) == 0:
|
||
continue
|
||
|
||
# 构建 LLM 的 prompt
|
||
prompt = create_prompt(content)
|
||
question = checklen([{"role": "user", "content": prompt}])
|
||
|
||
# 调用 LLM 生成问答对
|
||
SparkApi.answer = "" # 清空之前的回答
|
||
SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question) # 调用API获取回答
|
||
|
||
# 将生成的文本分割为问题和回答
|
||
answer_text = SparkApi.answer.strip()
|
||
|
||
# 解析多个问答对
|
||
qa_pairs = parse_multiple_qa(answer_text)
|
||
|
||
for qa_pair in qa_pairs:
|
||
conversation = {
|
||
"input": qa_pair['input'],
|
||
"output": qa_pair['output']
|
||
}
|
||
|
||
# 将对话数据添加到文件中
|
||
with open(file_name, 'a', encoding='utf-8') as file:
|
||
json.dump(conversation, file, ensure_ascii=False)
|
||
file.write("\n")
|
||
|
||
# 每处理完一个段落,保存当前的进度索引
|
||
save_checkpoint(i)
|
||
|
||
print(f"已生成 {file_name} 文件,包含问答对。")
|
||
|