Upload QA generation pipeline
commit 57a9db4c5b (parent 54ee4010be)
scripts/qa_generation/README.md (new file, 43 lines)
@@ -0,0 +1,43 @@
# QA Generation Pipeline

## 1. Usage

Check that the dependencies listed in `requirements.txt` are installed.

Then set the required API KEY in `config/config.py` and launch the pipeline from `main.py`. The generated QA pairs are written in jsonl format under `data/generated`.
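For a quick start, a minimal launch might look like the sketch below. It simply calls `generate_qa` from `main.py` with its defaults, and assumes you run it from the `scripts/qa_generation` directory after `pip install -r requirements.txt`:

```python
# Minimal launch sketch: call the pipeline entry point directly.
from main import generate_qa

# model_name: only 'qwen' is implemented for now.
# interval: flush the generated QA pairs to disk every N windows.
generate_qa(model_name='qwen', interval=1)
```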
You can tweak `system_prompt.md` to improve the diversity and stability of the generated pairs.

### 1.1 Obtaining an API KEY

Currently only qwen is included.

#### 1.1.1 Qwen

Go to [模型服务灵积-API-KEY管理 (aliyun.com)](https://dashscope.console.aliyun.com/apiKey), click "Create API-KEY", and paste the key you obtain into `DASHSCOPE_API_KEY` in `config/config.py`.
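The key is stored as a plain module-level constant in `config/config.py`; replace the placeholder value with your own key:

```python
# config/config.py
DASHSCOPE_API_KEY = 'sk-xxxxxxxx'  # placeholder; paste your DashScope key here
```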
## 2. Notes

### 2.1 System Prompt

Note that the current parsing logic assumes the model wraps its answer in a markdown-fenced JSON block; any edits to the system prompt must keep this behavior intact.
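Concretely, `capture_qa` in `util/data_loader.py` grabs the first fenced block with a regex, so any response the pipeline can use must contain a block like the one in this minimal sketch of the parsing contract:

```python
import re, json

# A response is usable only if it contains a ```json ... ``` block.
response = '```json\n[{"question": "...", "answer": "..."}]\n```'

match = re.search(r'```json(.*?)```', response, re.DOTALL)
qa_pairs = json.loads(match.group(1)) if match else None
```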
### 2.2 Sliding Window

The sliding window's `window_size` and `overlap_size` can both be changed in the `get_txt_content` function in `util/data_loader.py`. The window currently slides over sentence-level splits.
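For illustration, with `window_size = 4` and `overlap_size = 2` consecutive windows advance by `window_size - overlap_size` sentences, so neighbouring windows share exactly two sentences (a sketch mirroring `get_txt_content`):

```python
sentences = [f's{i}' for i in range(8)]  # stand-ins for real sentences
window_size, overlap_size = 4, 2
step = window_size - overlap_size

windows = ['\n'.join(sentences[i:i + window_size])
           for i in range(0, len(sentences) - window_size + 1, step)]
# three windows: [s0..s3], [s2..s5], [s4..s7]
```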
### 2.3 Corpus Format

Only txt is supported at the moment. Put your cleaned book texts under the `data` folder; the program recursively collects every txt file in that folder.
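The discovery step is equivalent to the sketch below (`get_file_list` in `util/data_loader.py` does the same with `os.walk`):

```python
from pathlib import Path

# Recursively collect every .txt corpus under data/.
txt_files = [str(p) for p in Path('data').rglob('*.txt')]
```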
## TODO

1. Support more models (Gemini, GPT, ChatGLM, ...)
2. Support more text formats (PDF, ...)
3. Support more text-splitting strategies
scripts/qa_generation/config/__init__.py (new file, empty)
scripts/qa_generation/config/config.py (new file, 28 lines)
@@ -0,0 +1,28 @@
import os

"""
Directory paths
"""
cur_dir = os.path.dirname(os.path.abspath(__file__))    # config
base_dir = os.path.dirname(cur_dir)                     # base

# model
model_dir = os.path.join(base_dir, 'model')             # model

# data
data_dir = os.path.join(base_dir, 'data')               # data
result_dir = os.path.join(data_dir, 'generated')        # result

# log
log_dir = os.path.join(base_dir, 'log')                 # log
log_file_path = os.path.join(log_dir, 'log.log')        # file

# system prompt
system_prompt_file_path = os.path.join(base_dir, 'system_prompt.md')    # system prompt


"""
Environment variables
"""
# api-keys
DASHSCOPE_API_KEY = 'sk-xxxxxxxx'
scripts/qa_generation/main.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import os
import json
from tqdm import tqdm
from datetime import datetime

from config.config import result_dir
from model.qwen import call_qwen_single_turn
from util.logger import get_logger
from util.data_loader import get_file_list, get_txt_content, capture_qa

logger = get_logger()

"""
Generate QA pairs
model_name: name of the model to call; only qwen is implemented for now
interval: storage interval, i.e. how many responses to buffer before each
          write; an overly small interval increases IO overhead
"""
def generate_qa(
    model_name: str = 'qwen',
    interval: int = 1,
):
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    if model_name == 'qwen':
        model_caller = call_qwen_single_turn
    else:
        logger.warning('This model is currently not supported and will call the default model - qwen.')
        model_caller = call_qwen_single_turn
        model_name = 'qwen'

    logger.info(f'The called model is: {model_name}.')
    logger.info(f'The storage interval is: {interval}.')

    # make sure the output directory exists before writing
    os.makedirs(result_dir, exist_ok=True)

    file_list = get_file_list()
    storage_counter = 0
    for file_name in file_list:
        contents = get_txt_content(file_name)
        storage_list = []
        # use only the base name so the full corpus path does not end up in the output path
        storage_jsonl_path = os.path.join(
            result_dir, f'{current_time}-{os.path.basename(file_name)}-{model_name}.jsonl')
        logger.info(f'The generated QA will be stored in {storage_jsonl_path}.')

        for content in tqdm(contents):
            response = model_caller(content)
            captured_qa = capture_qa(response)
            if captured_qa is None:
                continue

            storage_list.extend(captured_qa)
            storage_counter += 1
            if storage_counter % interval == 0:
                storage_counter = 0
                with open(storage_jsonl_path, 'a', encoding='utf-8') as f:
                    for item in storage_list:
                        f.write(json.dumps(item, ensure_ascii=False) + '\n')
                storage_list = []

        # flush any leftover pairs for this file
        if storage_list:
            with open(storage_jsonl_path, 'a', encoding='utf-8') as f:
                for item in storage_list:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')


if __name__ == '__main__':
    generate_qa()
scripts/qa_generation/model/__init__.py (new file, empty)
scripts/qa_generation/model/gemini.py (new file, empty)
scripts/qa_generation/model/glm.py (new file, empty)
scripts/qa_generation/model/gpt.py (new file, empty)
scripts/qa_generation/model/qwen.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import dashscope
from http import HTTPStatus
from dashscope import Generation
from dashscope.api_entities.dashscope_response import Role

from config.config import DASHSCOPE_API_KEY
from util.logger import get_logger
from util.prompt_loader import load_system_prompt


dashscope.api_key = DASHSCOPE_API_KEY

logger = get_logger()


def call_qwen_single_turn(query: str) -> str:
    # single-turn call: system prompt plus one user message, no history
    messages = [
        {
            'role': Role.SYSTEM,
            'content': load_system_prompt()
        },
        {
            'role': Role.USER,
            'content': query
        }
    ]
    response = Generation.call(
        model='qwen-max-1201',
        messages=messages,
        result_format='message',
        stream=False,
        incremental_output=False
    )
    if response.status_code == HTTPStatus.OK:
        return response.output.choices[0]['message']['content']
    else:
        logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        return ""
scripts/qa_generation/requirements.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
dashscope
loguru
tqdm
scripts/qa_generation/system_prompt.md (new file, 24 lines)
@@ -0,0 +1,24 @@
You are a QA-pair generation bot. Based on the [psychology book content] I provide, you will automatically generate suitable QA pairs, with the following requirements:

- For the text I give you, generate five such QA pairs
- The QA pairs must not repeat, and answers must not be too long
- Answer in Simplified Chinese
- The generated QA pairs must be wrapped in a markdown-formatted json code block

Here is the reference format:

```json
[
    {
        "question": "...",
        "answer": "..."
    },
    {
        "question": "...",
        "answer": "..."
    },
    ...
]
```

Here is the given text content:
scripts/qa_generation/util/__init__.py (new file, empty)
scripts/qa_generation/util/data_loader.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import os
import re
import json
from typing import List, Dict

from config.config import data_dir
from util.logger import get_logger

logger = get_logger()

"""
Recursively collect the list of all .txt files under data_dir
"""
def get_file_list() -> List[str]:
    txt_files = []
    txt_exist_flag = False
    for root, dirs, files in os.walk(data_dir):
        for file in files:
            if file.endswith('.txt'):
                txt_exist_flag = True
                txt_files.append(os.path.join(root, file))

    if not txt_exist_flag:
        logger.warning(f'No txt text found in {data_dir}, please check!')
    return txt_files

"""
Read the full content of a txt file and return it as a List of
sentence-based windows
file_path: path of the txt file
window_size: sliding-window size, in sentences
overlap_size: overlap size, in sentences
"""
def get_txt_content(
    file_path: str,
    window_size: int = 6,
    overlap_size: int = 2
) -> List[str]:
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read().strip()

    # simple implementation: split on 。!? and strip in-sentence whitespace
    sentences = re.split(r'(?<=[。!?])\s+', content)
    sentences = [s.replace(' ', '').replace('\t', '') for s in sentences]

    # sliding window
    res = []
    sentences_amount = len(sentences)
    start_index, end_index = 0, sentences_amount - window_size
    ## check length
    if window_size <= overlap_size:
        logger.error("window_size must be greater than overlap_size")
        return None
    if window_size >= sentences_amount:
        logger.warning("window_size exceeds the amount of sentences, and the complete text content will be returned")
        return ['\n'.join(sentences)]

    # advance by window_size - overlap_size so that consecutive windows
    # share exactly overlap_size sentences
    for i in range(start_index, end_index + 1, window_size - overlap_size):
        res.append('\n'.join(sentences[i : i + window_size]))
    return res


"""
Extract the QA pairs from the model response
"""
def capture_qa(content: str) -> List[Dict]:
    # capture only the first json block
    match = re.search(r'```json(.*?)```', content, re.DOTALL)

    if match:
        block = match.group(1)
        try:
            parsed_data = json.loads(block)
        except json.JSONDecodeError:
            # the model occasionally emits malformed JSON; skip such windows
            logger.warning("Failed to parse the captured JSON block.")
            return None
        return parsed_data
    else:
        logger.warning("No JSON block found.")
        return None
scripts/qa_generation/util/logger.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from loguru import logger

from config.config import log_file_path


# configure() replaces any previously registered handlers, so the file sink
# with rotation is set up exactly once at import time
logger.configure(
    handlers=[
        dict(sink=log_file_path, rotation="500 MB", format="{time} {level} {message}"),
    ]
)


def get_logger():
    return logger
scripts/qa_generation/util/prompt_loader.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from config.config import system_prompt_file_path


def load_system_prompt() -> str:
    with open(system_prompt_file_path, 'r', encoding='utf-8') as f:
        system_prompt = f.read()
    return system_prompt