Upload QA generation pipeline
This commit is contained in:
		
							parent
							
								
									54ee4010be
								
							
						
					
					
						commit
						57a9db4c5b
					
				
							
								
								
									
										43
									
								
								scripts/qa_generation/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								scripts/qa_generation/README.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,43 @@ | ||||
| # QA Generation Pipeline | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| ## 1. 使用方法 | ||||
| 
 | ||||
| 检查 `requirements.txt` 中的依赖是否满足。 | ||||
| 
 | ||||
| 而后,在 `config/config.py` 配置所需的 API KEY,从 `main.py` 启动即可。生成的 QA 对会以 jsonl 的格式存在 `data/generated` 下。 | ||||
| 
 | ||||
| 可以调整 `system_prompt.md`,增强生成的多样性和稳定性。 | ||||
| 
 | ||||
| ### 1.1 API KEY 获取方法 | ||||
| 
 | ||||
| 目前仅包含了 qwen。 | ||||
| 
 | ||||
| #### 1.1.1 Qwen | ||||
| 
 | ||||
| 前往[模型服务灵积-API-KEY管理 (aliyun.com)](https://dashscope.console.aliyun.com/apiKey),点击”创建新的 API-KEY“,将获取的 API KEY 填至 `config/config.py` 中的 `DASHSCOPE_API_KEY` 即可。 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| ## 2. 注意事项 | ||||
| 
 | ||||
| ### 2.1 系统提示 System Prompt | ||||
| 
 | ||||
| 注意,目前的解析方案是基于模型会生成 markdown 包裹的 json 块的前提的,更改 system prompt 时需要保证这一点不变。 | ||||
| 
 | ||||
| ### 2.2 滑动窗口 Sliding Window | ||||
| 
 | ||||
| 滑动窗口的 `window_size` 和 `overlap_size` 都可以在 `util/data_loader.py` 中的 `get_txt_content` 函数中更改。目前是按照句子分割的滑动窗口。 | ||||
| 
 | ||||
| ### 2.3 书本文件格式 Corpus Format | ||||
| 
 | ||||
| 目前仅支持了 txt 格式,可以将清洗好的书籍文本放在 `data` 文件夹下,程序会递归检索该文件夹下的所有 txt 文件。 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| ## TODO | ||||
| 
 | ||||
| 1. 支持更多模型(Gemini、GPT、ChatGLM……) | ||||
| 2. 支持更多文本格式(PDF……) | ||||
| 3. 支持更多切分文本的方式 | ||||
							
								
								
									
										0
									
								
								scripts/qa_generation/config/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								scripts/qa_generation/config/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										28
									
								
								scripts/qa_generation/config/config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								scripts/qa_generation/config/config.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,28 @@ | ||||
import os

# ---------------------------------------------------------------------------
# Directory layout (all paths derived from this file's location)
# ---------------------------------------------------------------------------
cur_dir = os.path.dirname(os.path.abspath(__file__))                    # config/
base_dir = os.path.dirname(cur_dir)                                     # project base

# model
model_dir = os.path.join(base_dir, 'model')                             # model wrappers

# data
data_dir = os.path.join(base_dir, 'data')                               # input corpus root
result_dir = os.path.join(data_dir, 'generated')                        # generated QA output

# log
log_dir = os.path.join(base_dir, 'log')                                 # log directory
log_file_path = os.path.join(log_dir, 'log.log')                        # log file

# system prompt
system_prompt_file_path = os.path.join(base_dir, 'system_prompt.md')    # system prompt file


# ---------------------------------------------------------------------------
# Credentials
# ---------------------------------------------------------------------------
# Prefer the DASHSCOPE_API_KEY environment variable so real keys never have
# to be committed; the literal placeholder is kept as the fallback so the
# original edit-this-file workflow still works.
DASHSCOPE_API_KEY = os.environ.get('DASHSCOPE_API_KEY', 'sk-xxxxxxxx')
							
								
								
									
										67
									
								
								scripts/qa_generation/main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								scripts/qa_generation/main.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,67 @@ | ||||
| import os | ||||
| import json | ||||
| from tqdm import tqdm | ||||
| from datetime import datetime | ||||
| 
 | ||||
| from config.config import result_dir | ||||
| from model.qwen import call_qwen_single_turn | ||||
| from util.logger import get_logger | ||||
| from util.data_loader import get_file_list, get_txt_content, capture_qa | ||||
| 
 | ||||
# Module-level logger shared across this script.
logger = get_logger()
| 
 | ||||
| """ | ||||
| 生成 QA 对 | ||||
| model_name: 可调用的模型名称,暂时只实现了 qwen | ||||
| interval: 存储间隔,即每隔多少条存一次文件,过密的间隔会增大 IO 开销 | ||||
| """ | ||||
| def generate_qa( | ||||
|     model_name: str = 'qwen', | ||||
|     interval: int = 1, | ||||
| ): | ||||
|     current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") | ||||
|      | ||||
|     if model_name == 'qwen': | ||||
|         model_caller = call_qwen_single_turn | ||||
|     else: | ||||
|         logger.warning('This model is currently not supported and will call the default model - qwen.') | ||||
|         model_caller = call_qwen_single_turn | ||||
|         model_name = 'qwen' | ||||
|      | ||||
|     logger.info(f'The called model is: {model_name}.') | ||||
|     logger.info(f'The storage interval is: {interval}.') | ||||
| 
 | ||||
|     file_list = get_file_list() | ||||
|     storage_counter = 0 | ||||
|     storage_list = [] | ||||
|     for file_name in file_list: | ||||
|         contents = get_txt_content(file_name) | ||||
|         storage_list = [] | ||||
|         storage_jsonl_path = os.path.join(result_dir, f'{current_time}-{file_name}-{model_name}.jsonl') | ||||
|         logger.info(f'The generated QA will be stored in {storage_jsonl_path}.') | ||||
|          | ||||
|         for content in tqdm(contents): | ||||
|             response = model_caller(content) | ||||
|             captured_qa = capture_qa(response) | ||||
|             if captured_qa is None: | ||||
|                 continue | ||||
|              | ||||
|             storage_list.extend(captured_qa) | ||||
|             storage_counter += 1 | ||||
|             if storage_counter % interval == 0: | ||||
|                 storage_counter = 0 | ||||
|                 with open(storage_jsonl_path, 'a', encoding='utf-8') as f: | ||||
|                     for item in storage_list: | ||||
|                         f.write(json.dumps(item, ensure_ascii=False) + '\n') | ||||
|                     storage_list = [] | ||||
|      | ||||
|         # 如果有剩余,存入 | ||||
|         if storage_list: | ||||
|             with open(storage_jsonl_path, 'a', encoding='utf-8') as f: | ||||
|                 for item in storage_list: | ||||
|                     f.write(json.dumps(item, ensure_ascii=False) + '\n') | ||||
|                 storage_list = [] | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     generate_qa() | ||||
							
								
								
									
										0
									
								
								scripts/qa_generation/model/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								scripts/qa_generation/model/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								scripts/qa_generation/model/gemini.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								scripts/qa_generation/model/gemini.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								scripts/qa_generation/model/glm.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								scripts/qa_generation/model/glm.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								scripts/qa_generation/model/gpt.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								scripts/qa_generation/model/gpt.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										41
									
								
								scripts/qa_generation/model/qwen.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								scripts/qa_generation/model/qwen.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,41 @@ | ||||
| import dashscope | ||||
| from http import HTTPStatus | ||||
| from dashscope import Generation | ||||
| from dashscope.api_entities.dashscope_response import Role | ||||
| 
 | ||||
| from config.config import DASHSCOPE_API_KEY | ||||
| from util.logger import get_logger | ||||
| from util.prompt_loader import load_system_prompt | ||||
| 
 | ||||
| 
 | ||||
# Authenticate the dashscope SDK once at import time with the configured key.
dashscope.api_key = DASHSCOPE_API_KEY

# Module-level logger shared by this module.
logger = get_logger()
| 
 | ||||
| 
 | ||||
def call_qwen_single_turn(query: str) -> str:
    """Run one single-turn chat completion against Qwen and return the reply.

    The system prompt is loaded fresh for every call. On any non-OK HTTP
    status the failure is logged and an empty string is returned instead of
    raising.
    """
    response = Generation.call(
        model='qwen-max-1201',
        messages=[
            {'role': Role.SYSTEM, 'content': load_system_prompt()},
            {'role': Role.USER, 'content': query},
        ],
        result_format='message',
        stream=False,
        incremental_output=False
    )
    if response.status_code != HTTPStatus.OK:
        logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        return ""
    return response.output.choices[0]['message']['content']
							
								
								
									
										3
									
								
								scripts/qa_generation/requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								scripts/qa_generation/requirements.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,3 @@ | ||||
| dashscope | ||||
| loguru | ||||
| tqdm | ||||
							
								
								
									
										24
									
								
								scripts/qa_generation/system_prompt.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								scripts/qa_generation/system_prompt.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,24 @@ | ||||
| 你是一名 QA 对生成机器人,你会根据我提供的【心理学书本内容】自动生成合适的 QA 对,要求如下: | ||||
| 
 | ||||
| - 对于我给的文本内容,你需要生成五条这样的 QA 对 | ||||
| - QA 对内容不能重复,答案不能过长 | ||||
| - 用简体中文回答 | ||||
| - 生成的 QA 对需要用 markdown 格式的 json 代码块包裹起来 | ||||
| 
 | ||||
| 以下是参考格式: | ||||
| 
 | ||||
| ```json | ||||
| [ | ||||
| 	{ | ||||
| 		"question": "...", | ||||
| 		"answer": "..." | ||||
| 	}, | ||||
| 	{ | ||||
| 		"question": "...", | ||||
| 		"answer": "..." | ||||
| 	}, | ||||
| 	... | ||||
| ] | ||||
| ``` | ||||
| 
 | ||||
| 以下是给定的文本内容: | ||||
							
								
								
									
										0
									
								
								scripts/qa_generation/util/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								scripts/qa_generation/util/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										75
									
								
								scripts/qa_generation/util/data_loader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								scripts/qa_generation/util/data_loader.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,75 @@ | ||||
| import os | ||||
| import re | ||||
| import json | ||||
| from typing import List, Dict | ||||
| 
 | ||||
| from config.config import data_dir | ||||
| from util.logger import get_logger | ||||
| 
 | ||||
# Module-level logger shared by the data-loading helpers.
logger = get_logger()
| 
 | ||||
| """ | ||||
| 递归获取 data_dir 下的所有 .txt 文件列表 | ||||
| """ | ||||
def get_file_list() -> List[str]:
    """Walk data_dir recursively and return the paths of all .txt files.

    Logs a warning when no .txt file is found at all.
    """
    txt_files = [
        os.path.join(root, name)
        for root, _, files in os.walk(data_dir)
        for name in files
        if name.endswith('.txt')
    ]
    if not txt_files:
        logger.warning(f'No txt text found in {data_dir}, please check!')
    return txt_files
| 
 | ||||
| """ | ||||
| 获取 txt 文本的所有内容,按句子返回 List | ||||
| file_path: txt 文本路径 | ||||
| window_size: 滑窗大小,单位为句子数 | ||||
| overlap_size: 重叠大小,单位为句子数 | ||||
| """ | ||||
| def get_txt_content( | ||||
|     file_path: str, | ||||
|     window_size: int = 6, | ||||
|     overlap_size: int = 2 | ||||
| ) -> List[str]: | ||||
|     with open(file_path, 'r', encoding='utf-8') as f: | ||||
|         content = f.read().strip() | ||||
| 
 | ||||
|     # 简单实现:按句号、感叹号、问号分割,并去除句内空白符 | ||||
|     sentences = re.split(r'(?<=[。!?])\s+', content) | ||||
|     sentences = [s.replace(' ', '').replace('\t', '') for s in sentences] | ||||
| 
 | ||||
|     # 滑窗 | ||||
|     res = [] | ||||
|     sentences_amount = len(sentences) | ||||
|     start_index, end_index = 0, sentences_amount - window_size | ||||
|     ## check length | ||||
|     if window_size < overlap_size: | ||||
|         logger.error("window_size must be greater than or equal to overlap_size") | ||||
|         return None | ||||
|     if window_size >= sentences_amount: | ||||
|         logger.warning("window_size exceeds the amount of sentences, and the complete text content will be returned") | ||||
|         return ['\n'.join(sentences)] | ||||
|      | ||||
|     for i in range(start_index, end_index + 1, overlap_size): | ||||
|         res.append('\n'.join(sentences[i : i + window_size])) | ||||
|     return res | ||||
| 
 | ||||
| 
 | ||||
| """ | ||||
| 提取返回的 QA 对 | ||||
| """ | ||||
| def capture_qa(content: str) -> List[Dict]: | ||||
|     # 只捕获第一个 json 块 | ||||
|     match = re.search(r'```json(.*?)```', content, re.DOTALL) | ||||
| 
 | ||||
|     if match: | ||||
|         block = match.group(1) | ||||
|         parsed_data = json.loads(block) | ||||
|         return parsed_data | ||||
|     else: | ||||
|         logger.warning("No JSON block found.") | ||||
|         return None | ||||
							
								
								
									
										14
									
								
								scripts/qa_generation/util/logger.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								scripts/qa_generation/util/logger.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,14 @@ | ||||
from loguru import logger

from config.config import log_file_path


def get_logger():
    """Return the process-wide loguru logger, configured below at import time."""
    return logger


# `logger.configure(handlers=...)` replaces every existing handler (including
# loguru's default stderr sink), so it is the single source of truth here.
# A preceding `logger.add(log_file_path, ...)` would just be discarded by it
# (and briefly register the same file twice), so no separate add() is done.
logger.configure(
    handlers=[
        dict(sink=log_file_path, rotation="500 MB", format="{time} {level} {message}"),
    ]
)
							
								
								
									
										7
									
								
								scripts/qa_generation/util/prompt_loader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								scripts/qa_generation/util/prompt_loader.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,7 @@ | ||||
| from config.config import system_prompt_file_path | ||||
| 
 | ||||
| 
 | ||||
def load_system_prompt() -> str:
    """Read and return the full system prompt text from its configured path."""
    with open(system_prompt_file_path, mode='r', encoding='utf-8') as fp:
        return fp.read()
		Loading…
	
		Reference in New Issue
	
	Block a user
	 Mxode
						Mxode