Upload QA generation pipeline

This commit is contained in:
Mxode 2024-03-07 17:56:07 +08:00
parent 54ee4010be
commit 57a9db4c5b
15 changed files with 302 additions and 0 deletions

View File

@ -0,0 +1,43 @@
# QA Generation Pipeline
## 1. 使用方法
检查 `requirements.txt` 中的依赖是否满足。
而后,在 `config/config.py` 配置所需的 API KEY,然后运行 `main.py` 启动即可。生成的 QA 对会以 jsonl 的格式存在 `data/generated` 下。
可以调整 `system_prompt.md`,增强生成的多样性和稳定性。
### 1.1 API KEY 获取方法
目前仅包含了 qwen。
#### 1.1.1 Qwen
前往[模型服务灵积-API-KEY管理 (aliyun.com)](https://dashscope.console.aliyun.com/apiKey),点击“创建新的 API-KEY”,将获取的 API KEY 填至 `config/config.py` 中的 `DASHSCOPE_API_KEY` 即可。
## 2. 注意事项
### 2.1 系统提示 System Prompt
注意,目前的解析方案是基于模型会生成 markdown 包裹的 json 块的前提的,更改 system prompt 时需要保证这一点不变。
### 2.2 滑动窗口 Sliding Window
滑动窗口的 `window_size` 和 `overlap_size` 都可以在 `util/data_loader.py` 中的 `get_txt_content` 函数中更改。目前是按照句子分割的滑动窗口。
### 2.3 书本文件格式 Corpus Format
目前仅支持了 txt 格式,可以将清洗好的书籍文本放在 `data` 文件夹下,程序会递归检索该文件夹下的所有 txt 文件。
## TODO
1. 支持更多模型:Gemini、GPT、ChatGLM……
2. 支持更多文本格式:PDF……
3. 支持更多切分文本的方式

View File

View File

@ -0,0 +1,28 @@
import os
"""
文件夹路径
"""
cur_dir = os.path.dirname(os.path.abspath(__file__)) # config
base_dir = os.path.dirname(cur_dir) # base
# model
model_dir = os.path.join(base_dir, 'model') # model
# data
data_dir = os.path.join(base_dir, 'data') # data
result_dir = os.path.join(data_dir, 'generated') # result
# log
log_dir = os.path.join(base_dir, 'log') # log
log_file_path = os.path.join(log_dir, 'log.log') # file
# system prompt
system_prompt_file_path = os.path.join(base_dir, 'system_prompt.md') # system prompt
"""
环境变量
"""
# api-keys
DASHSCOPE_API_KEY = 'sk-xxxxxxxx'

View File

@ -0,0 +1,67 @@
import os
import json
from tqdm import tqdm
from datetime import datetime
from config.config import result_dir
from model.qwen import call_qwen_single_turn
from util.logger import get_logger
from util.data_loader import get_file_list, get_txt_content, capture_qa
logger = get_logger()


def generate_qa(
    model_name: str = 'qwen',
    interval: int = 1,
):
    """Generate QA pairs from every corpus txt file and store them as jsonl.

    Args:
        model_name: Name of the callable model; only 'qwen' is implemented.
            Unknown names fall back to qwen with a warning.
        interval: Storage interval, i.e. flush to disk every `interval`
            responses. A smaller interval increases IO overhead.
    """
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    if model_name == 'qwen':
        model_caller = call_qwen_single_turn
    else:
        logger.warning('This model is currently not supported and will call the default model - qwen.')
        model_caller = call_qwen_single_turn
        model_name = 'qwen'
    logger.info(f'The called model is: {model_name}.')
    logger.info(f'The storage interval is: {interval}.')

    # Make sure the output directory exists before appending results.
    os.makedirs(result_dir, exist_ok=True)

    for file_name in get_file_list():
        contents = get_txt_content(file_name)
        # Reset per file so a leftover counter from the previous file cannot
        # shift this file's flush schedule.
        storage_counter = 0
        storage_list = []
        # BUG FIX: file_name is a full path (see get_file_list), so embedding
        # it verbatim in the output file name would inject path separators.
        base_name = os.path.splitext(os.path.basename(file_name))[0]
        storage_jsonl_path = os.path.join(
            result_dir, f'{current_time}-{base_name}-{model_name}.jsonl')
        logger.info(f'The generated QA will be stored in {storage_jsonl_path}.')
        for content in tqdm(contents):
            response = model_caller(content)
            captured_qa = capture_qa(response)
            if captured_qa is None:
                continue
            storage_list.extend(captured_qa)
            storage_counter += 1
            if storage_counter % interval == 0:
                storage_counter = 0
                _append_jsonl(storage_jsonl_path, storage_list)
                storage_list = []
        # Flush whatever is left over for this file.
        if storage_list:
            _append_jsonl(storage_jsonl_path, storage_list)


def _append_jsonl(path: str, items) -> None:
    """Append each item of *items* to *path* as one JSON line (UTF-8)."""
    with open(path, 'a', encoding='utf-8') as f:
        for item in items:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')


if __name__ == '__main__':
    generate_qa()

View File

View File

View File

View File

View File

@ -0,0 +1,41 @@
import dashscope
from http import HTTPStatus
from dashscope import Generation
from dashscope.api_entities.dashscope_response import Role
from config.config import DASHSCOPE_API_KEY
from util.logger import get_logger
from util.prompt_loader import load_system_prompt
dashscope.api_key = DASHSCOPE_API_KEY
logger = get_logger()


def call_qwen_single_turn(query: str) -> str:
    """Run one single-turn chat completion against Qwen.

    Loads the system prompt from disk, sends it together with *query*,
    and returns the model's reply text. On API failure the error is
    logged and an empty string is returned.
    """
    chat_messages = [
        {'role': Role.SYSTEM, 'content': load_system_prompt()},
        {'role': Role.USER, 'content': query},
    ]
    response = Generation.call(
        model='qwen-max-1201',
        messages=chat_messages,
        result_format='message',
        stream=False,
        incremental_output=False
    )
    # Guard clause: bail out on any non-OK status with a logged error.
    if response.status_code != HTTPStatus.OK:
        logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        return ""
    return response.output.choices[0]['message']['content']

View File

@ -0,0 +1,3 @@
dashscope
loguru
tqdm

View File

@ -0,0 +1,24 @@
你是一名 QA 对生成机器人,你会根据我提供的【心理学书本内容】自动生成合适的 QA 对,要求如下:
- 对于我给的文本内容,你需要生成五条这样的 QA 对
- QA 对内容不能重复,答案不能过长
- 用简体中文回答
- 生成的 QA 对需要用 markdown 格式的 json 代码块包裹起来
以下是参考格式:
```json
[
{
"question": "...",
"answer": "..."
},
{
"question": "...",
"answer": "..."
},
...
]
```
以下是给定的文本内容:

View File

View File

@ -0,0 +1,75 @@
import os
import re
import json
from typing import List, Dict
from config.config import data_dir
from util.logger import get_logger
logger = get_logger()
"""
递归获取 data_dir 下的所有 .txt 文件列表
"""
def get_file_list() -> List[str]:
txt_files = []
txt_exist_flag = False
for root, dirs, files in os.walk(data_dir):
for file in files:
if file.endswith('.txt'):
txt_exist_flag = True
txt_files.append(os.path.join(root, file))
if not txt_exist_flag:
logger.warning(f'No txt text found in {data_dir}, please check!')
return txt_files
"""
获取 txt 文本的所有内容按句子返回 List
file_path: txt 文本路径
window_size: 滑窗大小单位为句子数
overlap_size: 重叠大小单位为句子数
"""
def get_txt_content(
file_path: str,
window_size: int = 6,
overlap_size: int = 2
) -> List[str]:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read().strip()
# 简单实现:按句号、感叹号、问号分割,并去除句内空白符
sentences = re.split(r'(?<=[。!?])\s+', content)
sentences = [s.replace(' ', '').replace('\t', '') for s in sentences]
# 滑窗
res = []
sentences_amount = len(sentences)
start_index, end_index = 0, sentences_amount - window_size
## check length
if window_size < overlap_size:
logger.error("window_size must be greater than or equal to overlap_size")
return None
if window_size >= sentences_amount:
logger.warning("window_size exceeds the amount of sentences, and the complete text content will be returned")
return ['\n'.join(sentences)]
for i in range(start_index, end_index + 1, overlap_size):
res.append('\n'.join(sentences[i : i + window_size]))
return res
"""
提取返回的 QA
"""
def capture_qa(content: str) -> List[Dict]:
# 只捕获第一个 json 块
match = re.search(r'```json(.*?)```', content, re.DOTALL)
if match:
block = match.group(1)
parsed_data = json.loads(block)
return parsed_data
else:
logger.warning("No JSON block found.")
return None

View File

@ -0,0 +1,14 @@
from loguru import logger
from config.config import log_file_path
def get_logger():
    """Return the module-wide loguru logger (configured below at import time)."""
    return logger


# Route all logging to the log file. `logger.configure(handlers=[...])`
# REPLACES every existing sink (including loguru's default stderr sink), so
# it is the single source of truth here; the previous separate `logger.add`
# call was dead code — its sink was immediately discarded by `configure`.
logger.configure(
    handlers=[
        dict(sink=log_file_path, rotation="500 MB", format="{time} {level} {message}"),
    ]
)

View File

@ -0,0 +1,7 @@
from config.config import system_prompt_file_path
def load_system_prompt() -> str:
    """Read and return the system prompt text from ``system_prompt_file_path``."""
    with open(system_prompt_file_path, 'r', encoding='utf-8') as prompt_file:
        return prompt_file.read()