# -*- coding: utf-8 -*- # @Time : 2024/10/22 # @Author : 黄子寒 # @File : generate_qa_with_multiple_pairs.py # @Project : EmoLLM import os import re from tqdm import tqdm import SparkApi import json appid = "f0f73de5" api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk" api_key = "5773f6f95563708de994d17b7ea5d414" # Spark服务地址及版本 domain = "4.0Ultra" Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat" # 准备存储清洗后的文本 text_data = [] # 断点文件,用于存储上次处理的段落索引 checkpoint_file = "output/progress_checkpoint.txt" # 加载处理好的文本文件 with open("../processPDF/cleaned_data.txt", "r", encoding="utf-8") as f: cleaned_text = f.read() # 自定义分割函数,按最大300字以内的句子段落 def split_text_to_sentences(text, max_length=300): sentences = re.split('(?<=。)', text) grouped_sentences = [] current_group = "" for sentence in sentences: if len(current_group) + len(sentence) <= max_length: current_group += sentence else: grouped_sentences.append(current_group.strip()) current_group = sentence if current_group: grouped_sentences.append(current_group.strip()) return grouped_sentences # 加载断点进度 def load_checkpoint(): if os.path.exists(checkpoint_file): with open(checkpoint_file, 'r') as f: return int(f.read().strip()) # 返回已处理的段落索引 return 0 # 没有断点则从0开始 # 保存断点进度 def save_checkpoint(index): with open(checkpoint_file, 'w') as f: f.write(str(index)) # 将文本按要求的长度进行分割 paragraphs = split_text_to_sentences(cleaned_text, 300) # 构建 LLM 生成 input 和 output 的详细 prompt,允许模型生成多个问答对 def create_prompt(content): prompt = ( f"你是一位油橄榄栽培专家。" f"根据以下内容生成一个或多个问题和回答对,请保证语句通顺有逻辑,同时忽略所有内容中和图示相关的内容:\n\n" f"内容:{content}\n\n" f"请以如下格式生成输出:\n" f"问题1:<在这里生成第一个问题>\n" f"回答1:<在这里生成第一个回答>\n" f"问题2:<在这里生成第二个问题(如有)>\n" f"回答2:<在这里生成第二个回答(如有)>\n" f"..." ) return prompt # 解析返回的问答对,处理多个问答对的情况 def parse_multiple_qa(answer_text): qa_pairs = [] # 通过正则表达式找到所有的问答对 pattern = re.compile(r"问题\d+:(.*?)回答\d+:(.*?)(问题|$)", re.S) matches = pattern.findall(answer_text) for match in matches: question = match[0].strip() answer = match[1].strip() qa_pairs.append({"input": question, "output": answer}) return qa_pairs # 迭代限制,防止API额度过大 def checklen(text): while len(text) > 80000: del text[0] return text if __name__ == '__main__': text_data.clear() file_name = 'output/train_optimized_multiple.jsonl' conversations = [] # 加载上次的进度 start_index = load_checkpoint() # 从断点开始继续生成问答对 # 从断点开始继续生成问答对 for i in tqdm(range(start_index, len(paragraphs))): # 处理所有剩余的段落 content = paragraphs[i].strip() # 去除段落前后的空格 print("====================\ncontent:", content, "\n==================\n") if len(content) == 0: continue # 构建 LLM 的 prompt prompt = create_prompt(content) question = checklen([{"role": "user", "content": prompt}]) # 调用 LLM 生成问答对 SparkApi.answer = "" # 清空之前的回答 SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question) # 调用API获取回答 # 将生成的文本分割为问题和回答 answer_text = SparkApi.answer.strip() # 解析多个问答对 qa_pairs = parse_multiple_qa(answer_text) for qa_pair in qa_pairs: conversation = { "input": qa_pair['input'], "output": qa_pair['output'] } # 将对话数据添加到文件中 with open(file_name, 'a', encoding='utf-8') as file: json.dump(conversation, file, ensure_ascii=False) file.write("\n") # 每处理完一个段落,保存当前的进度索引 save_checkpoint(i) print(f"已生成 {file_name} 文件,包含问答对。")