diff --git a/scripts/qa_generation/README.md b/scripts/qa_generation/README.md
new file mode 100644
index 0000000..030ec70
--- /dev/null
+++ b/scripts/qa_generation/README.md
@@ -0,0 +1,43 @@
+# QA Generation Pipeline
+
+
+
+## 1. Usage
+
+Check that the dependencies listed in `requirements.txt` are satisfied.
+
+Then configure the required API key in `config/config.py` and start the pipeline from `main.py`. The generated QA pairs are stored in jsonl format under `data/generated`.
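+
+Instead of running `python main.py`, you can also call the entry point yourself, e.g. to change how often results are flushed to disk (a minimal sketch; the parameters are the ones defined in `main.py`, and it assumes you run it from the `scripts/qa_generation` directory):
+
+```python
+from main import generate_qa
+
+# generate QA pairs with qwen, flushing to the jsonl file every 10 responses
+generate_qa(model_name='qwen', interval=10)
+```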
+
+You can adjust `system_prompt.md` to improve the diversity and stability of the generated QA pairs.
+
+### 1.1 Obtaining an API Key
+
+Only Qwen is covered at the moment.
+
+#### 1.1.1 Qwen
+
+Go to [模型服务灵积-API-KEY管理 (aliyun.com)](https://dashscope.console.aliyun.com/apiKey), click "创建新的 API-KEY" (create a new API-KEY), and fill the key you obtain into `DASHSCOPE_API_KEY` in `config/config.py`.
+
+
+
+## 2. Notes
+
+### 2.1 System Prompt
+
+Note that the current parsing scheme relies on the assumption that the model wraps its generated json block in markdown; any change to the system prompt must keep this guarantee intact.
+
+### 2.2 Sliding Window
+
+Both `window_size` and `overlap_size` of the sliding window can be changed in the `get_txt_content` function in `util/data_loader.py`. The window currently slides over sentence-level splits; see the toy illustration below.
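+
+As a standalone sketch of how the windows are formed (hypothetical toy sentences, not the actual loader code): the stride is `window_size - overlap_size`, so adjacent windows share exactly `overlap_size` sentences.
+
+```python
+sentences = ['s1。', 's2。', 's3。', 's4。', 's5。', 's6。']
+window_size, overlap_size = 3, 1
+stride = window_size - overlap_size  # adjacent windows share overlap_size sentences
+
+windows = [sentences[i:i + window_size]
+           for i in range(0, len(sentences) - window_size + 1, stride)]
+# [['s1。', 's2。', 's3。'], ['s3。', 's4。', 's5。']]
+```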
+
+### 2.3 Corpus Format
+
+Only the txt format is supported for now. Put the cleaned book texts into the `data` folder; the program recursively collects all txt files under it.
+
+
+
+## TODO
+
+1. Support more models (Gemini, GPT, ChatGLM, ...)
+2. Support more text formats (PDF, ...)
+3. Support more ways of splitting the text
\ No newline at end of file
diff --git a/scripts/qa_generation/config/__init__.py b/scripts/qa_generation/config/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/qa_generation/config/config.py b/scripts/qa_generation/config/config.py
new file mode 100644
index 0000000..700a96a
--- /dev/null
+++ b/scripts/qa_generation/config/config.py
@@ -0,0 +1,28 @@
+import os
+
+"""
+Directory paths
+"""
+cur_dir = os.path.dirname(os.path.abspath(__file__)) # config
+base_dir = os.path.dirname(cur_dir) # project base
+
+# model
+model_dir = os.path.join(base_dir, 'model') # model
+
+# data
+data_dir = os.path.join(base_dir, 'data') # data
+result_dir = os.path.join(data_dir, 'generated') # results
+
+# log
+log_dir = os.path.join(base_dir, 'log') # log
+log_file_path = os.path.join(log_dir, 'log.log') # log file
+
+# system prompt
+system_prompt_file_path = os.path.join(base_dir, 'system_prompt.md') # system prompt
+
+
+"""
+Environment variables
+"""
+# api keys: fill in your DashScope key here, or export DASHSCOPE_API_KEY
+DASHSCOPE_API_KEY = os.environ.get('DASHSCOPE_API_KEY', 'sk-xxxxxxxx')
diff --git a/scripts/qa_generation/main.py b/scripts/qa_generation/main.py
new file mode 100644
index 0000000..1eb73e9
--- /dev/null
+++ b/scripts/qa_generation/main.py
@@ -0,0 +1,67 @@
+import os
+import json
+from tqdm import tqdm
+from datetime import datetime
+
+from config.config import result_dir
+from model.qwen import call_qwen_single_turn
+from util.logger import get_logger
+from util.data_loader import get_file_list, get_txt_content, capture_qa
+
+logger = get_logger()
+
+"""
+Generate QA pairs
+model_name: name of the model to call; only qwen is implemented for now
+interval: storage interval, i.e. how many responses to accumulate before
+          flushing to file; a very small interval increases IO overhead
+"""
+def generate_qa(
+    model_name: str = 'qwen',
+    interval: int = 1,
+):
+    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+
+    if model_name == 'qwen':
+        model_caller = call_qwen_single_turn
+    else:
+        logger.warning('This model is currently not supported; the default model (qwen) will be called instead.')
+        model_caller = call_qwen_single_turn
+        model_name = 'qwen'
+
+    logger.info(f'The called model is: {model_name}.')
+    logger.info(f'The storage interval is: {interval}.')
+
+    file_list = get_file_list()
+    os.makedirs(result_dir, exist_ok=True)
+    storage_counter = 0
+    for file_path in file_list:
+        contents = get_txt_content(file_path)
+        if not contents:
+            continue
+
+        storage_list = []
+        # use the bare file name, not the full path, in the output file name
+        file_name = os.path.basename(file_path)
+        storage_jsonl_path = os.path.join(
+            result_dir, f'{current_time}-{file_name}-{model_name}.jsonl')
+        logger.info(f'The generated QA will be stored in {storage_jsonl_path}.')
+
+        for content in tqdm(contents):
+            response = model_caller(content)
+            captured_qa = capture_qa(response)
+            if captured_qa is None:
+                continue
+
+            storage_list.extend(captured_qa)
+            storage_counter += 1
+            if storage_counter % interval == 0:
+                storage_counter = 0
+                with open(storage_jsonl_path, 'a', encoding='utf-8') as f:
+                    for item in storage_list:
+                        f.write(json.dumps(item, ensure_ascii=False) + '\n')
+                storage_list = []
+
+        # flush whatever is left over
+        if storage_list:
+            with open(storage_jsonl_path, 'a', encoding='utf-8') as f:
+                for item in storage_list:
+                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
+            storage_list = []
+
+
+if __name__ == '__main__':
+    generate_qa()
diff --git a/scripts/qa_generation/model/__init__.py b/scripts/qa_generation/model/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/qa_generation/model/gemini.py b/scripts/qa_generation/model/gemini.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/qa_generation/model/glm.py b/scripts/qa_generation/model/glm.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/qa_generation/model/gpt.py b/scripts/qa_generation/model/gpt.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/qa_generation/model/qwen.py b/scripts/qa_generation/model/qwen.py
new file mode 100644
index 0000000..ed27c4a
--- /dev/null
+++ b/scripts/qa_generation/model/qwen.py
@@ -0,0 +1,41 @@
+import dashscope
+from http import HTTPStatus
+from dashscope import Generation
+from dashscope.api_entities.dashscope_response import Role
+
+from config.config import DASHSCOPE_API_KEY
+from util.logger import get_logger
+from util.prompt_loader import load_system_prompt
+
+
+dashscope.api_key = DASHSCOPE_API_KEY
+
+logger = get_logger()
+
+
+"""
+Single-turn call to Qwen via DashScope; returns the reply text, or an empty
+string on failure
+"""
+def call_qwen_single_turn(query: str) -> str:
+    messages = [
+        {
+            'role': Role.SYSTEM,
+            'content': load_system_prompt()
+        },
+        {
+            'role': Role.USER,
+            'content': query
+        }
+    ]
+    response = Generation.call(
+        model='qwen-max-1201',
+        messages=messages,
+        result_format='message',
+        stream=False,
+        incremental_output=False
+    )
+    if response.status_code == HTTPStatus.OK:
+        return response.output.choices[0]['message']['content']
+    else:
+        logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
+            response.request_id, response.status_code,
+            response.code, response.message
+        ))
+        return ""
diff --git a/scripts/qa_generation/requirements.txt b/scripts/qa_generation/requirements.txt
new file mode 100644
index 0000000..bc615fa
--- /dev/null
+++ b/scripts/qa_generation/requirements.txt
@@ -0,0 +1,3 @@
+dashscope
+loguru
+tqdm
\ No newline at end of file
diff --git a/scripts/qa_generation/system_prompt.md b/scripts/qa_generation/system_prompt.md
new file mode 100644
index 0000000..5fa6efb
--- /dev/null
+++ b/scripts/qa_generation/system_prompt.md
@@ -0,0 +1,24 @@
+You are a QA-pair generation bot. You will automatically generate suitable QA pairs from the [psychology book content] I provide, subject to the following requirements:
+
+- For the text I give you, generate five such QA pairs
+- The QA pairs must not repeat each other, and the answers must not be too long
+- Answer in Simplified Chinese
+- Wrap the generated QA pairs in a markdown-formatted json code block
+
+Here is the reference format:
+
+```json
+[
+    {
+        "question": "...",
+        "answer": "..."
+    },
+    {
+        "question": "...",
+        "answer": "..."
+    },
+    ...
+]
+```
+
+Here is the given text:
diff --git a/scripts/qa_generation/util/__init__.py b/scripts/qa_generation/util/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/qa_generation/util/data_loader.py b/scripts/qa_generation/util/data_loader.py
new file mode 100644
index 0000000..f739b66
--- /dev/null
+++ b/scripts/qa_generation/util/data_loader.py
@@ -0,0 +1,75 @@
+import os
+import re
+import json
+from typing import Dict, List, Optional
+
+from config.config import data_dir
+from util.logger import get_logger
+
+logger = get_logger()
+
+"""
+Recursively collect all .txt files under data_dir
+"""
+def get_file_list() -> List[str]:
+    txt_files = []
+    for root, dirs, files in os.walk(data_dir):
+        for file in files:
+            if file.endswith('.txt'):
+                txt_files.append(os.path.join(root, file))
+
+    if not txt_files:
+        logger.warning(f'No txt text found in {data_dir}, please check!')
+    return txt_files
+
+"""
+Read the whole txt file and return its content as a List of sentence windows
+file_path: path to the txt file
+window_size: window size, in sentences
+overlap_size: overlap between adjacent windows, in sentences
+"""
+def get_txt_content(
+    file_path: str,
+    window_size: int = 6,
+    overlap_size: int = 2
+) -> Optional[List[str]]:
+    with open(file_path, 'r', encoding='utf-8') as f:
+        content = f.read().strip()
+
+    # simple approach: split after Chinese full stops, exclamation marks and
+    # question marks, then strip whitespace inside each sentence
+    sentences = re.split(r'(?<=[。!?])', content)
+    sentences = [re.sub(r'\s+', '', s) for s in sentences if s.strip()]
+
+    # sliding window
+    sentences_amount = len(sentences)
+    ## sanity checks
+    if window_size <= overlap_size:
+        logger.error("window_size must be greater than overlap_size")
+        return None
+    if window_size >= sentences_amount:
+        logger.warning("window_size exceeds the amount of sentences, and the complete text content will be returned")
+        return ['\n'.join(sentences)]
+
+    # adjacent windows share overlap_size sentences, so the stride is
+    # window_size - overlap_size
+    stride = window_size - overlap_size
+    res = []
+    for i in range(0, sentences_amount - window_size + 1, stride):
+        res.append('\n'.join(sentences[i : i + window_size]))
+    return res
+
+
+"""
+Extract the returned QA pairs from the model response
+"""
+def capture_qa(content: str) -> Optional[List[Dict]]:
+    # capture only the first json block
+    match = re.search(r'```json(.*?)```', content, re.DOTALL)
+
+    if not match:
+        logger.warning("No JSON block found.")
+        return None
+
+    try:
+        return json.loads(match.group(1))
+    except json.JSONDecodeError:
+        logger.warning("Failed to parse the captured JSON block.")
+        return None
diff --git a/scripts/qa_generation/util/logger.py b/scripts/qa_generation/util/logger.py
new file mode 100644
index 0000000..430f126
--- /dev/null
+++ b/scripts/qa_generation/util/logger.py
@@ -0,0 +1,14 @@
+import sys
+
+from loguru import logger
+
+from config.config import log_file_path
+
+def get_logger():
+    return logger
+
+# configure() replaces any previously registered handlers (including the
+# default stderr one), so declare both sinks here in a single call
+logger.configure(
+    handlers=[
+        dict(sink=sys.stderr),
+        dict(sink=log_file_path, rotation="500 MB", format="{time} {level} {message}"),
+    ]
+)
diff --git a/scripts/qa_generation/util/prompt_loader.py b/scripts/qa_generation/util/prompt_loader.py
new file mode 100644
index 0000000..1503dea
--- /dev/null
+++ b/scripts/qa_generation/util/prompt_loader.py
@@ -0,0 +1,7 @@
+from config.config import system_prompt_file_path
+
+
+def load_system_prompt() -> str:
+    with open(system_prompt_file_path, 'r', encoding='utf-8') as f:
+        system_prompt = f.read()
+    return system_prompt
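+
+
+if __name__ == '__main__':
+    # Illustrative smoke test, not used by the pipeline: print the start of
+    # the loaded prompt to verify the configured path and encoding.
+    print(load_system_prompt()[:200])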