Upload QA generation pipeline
This commit is contained in:
parent
54ee4010be
commit
57a9db4c5b
43
scripts/qa_generation/README.md
Normal file
43
scripts/qa_generation/README.md
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
# QA Generation Pipeline
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 1. 使用方法
|
||||||
|
|
||||||
|
检查 `requirements.txt` 中的依赖是否满足。
|
||||||
|
|
||||||
|
而后,在 `config/config.py` 配置所需的 API KEY,从 `main.py` 启动即可。生成的 QA 对会以 jsonl 的格式存在 `data/generated` 下。
|
||||||
|
|
||||||
|
可以调整 `system_prompt.md`,增强生成的多样性和稳定性。
|
||||||
|
|
||||||
|
### 1.1 API KEY 获取方法
|
||||||
|
|
||||||
|
目前仅包含了 qwen。
|
||||||
|
|
||||||
|
#### 1.1.1 Qwen
|
||||||
|
|
||||||
|
前往[模型服务灵积-API-KEY管理 (aliyun.com)](https://dashscope.console.aliyun.com/apiKey),点击“创建新的 API-KEY”,将获取的 API KEY 填至 `config/config.py` 中的 `DASHSCOPE_API_KEY` 即可。
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 2. 注意事项
|
||||||
|
|
||||||
|
### 2.1 系统提示 System Prompt
|
||||||
|
|
||||||
|
注意,目前的解析方案是基于模型会生成 markdown 包裹的 json 块的前提的,更改 system prompt 时需要保证这一点不变。
|
||||||
|
|
||||||
|
### 2.2 滑动窗口 Sliding Window
|
||||||
|
|
||||||
|
滑动窗口的 `window_size` 和 `overlap_size` 都可以在 `util/data_loader.py` 中的 `get_txt_content` 函数中更改。目前是按照句子分割的滑动窗口。
|
||||||
|
|
||||||
|
### 2.3 书本文件格式 Corpus Format
|
||||||
|
|
||||||
|
目前仅支持了 txt 格式,可以将清洗好的书籍文本放在 `data` 文件夹下,程序会递归检索该文件夹下的所有 txt 文件。
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## TODO
|
||||||
|
|
||||||
|
1. 支持更多模型(Gemini、GPT、ChatGLM……)
|
||||||
|
2. 支持更多文本格式(PDF……)
|
||||||
|
3. 支持更多切分文本的方式
|
0
scripts/qa_generation/config/__init__.py
Normal file
0
scripts/qa_generation/config/__init__.py
Normal file
28
scripts/qa_generation/config/config.py
Normal file
28
scripts/qa_generation/config/config.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import os

"""
Directory layout. All paths are derived from this file's location so the
pipeline works no matter where it is checked out.
"""
cur_dir = os.path.dirname(os.path.abspath(__file__))  # .../qa_generation/config
base_dir = os.path.dirname(cur_dir)                   # .../qa_generation

# model
model_dir = os.path.join(base_dir, 'model')  # model callers (qwen, ...)

# data
data_dir = os.path.join(base_dir, 'data')            # input corpus root (txt files)
result_dir = os.path.join(data_dir, 'generated')     # generated QA jsonl output

# log
log_dir = os.path.join(base_dir, 'log')              # log directory
log_file_path = os.path.join(log_dir, 'log.log')     # log file

# system prompt
system_prompt_file_path = os.path.join(base_dir, 'system_prompt.md')  # system prompt

"""
Environment / credentials
"""
# api-keys
# SECURITY FIX: prefer the environment variable so the real key never has to
# be committed to the repository; the placeholder keeps the original
# edit-this-file workflow working as a fallback.
DASHSCOPE_API_KEY = os.environ.get('DASHSCOPE_API_KEY', 'sk-xxxxxxxx')
|
67
scripts/qa_generation/main.py
Normal file
67
scripts/qa_generation/main.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from config.config import result_dir
|
||||||
|
from model.qwen import call_qwen_single_turn
|
||||||
|
from util.logger import get_logger
|
||||||
|
from util.data_loader import get_file_list, get_txt_content, capture_qa
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
"""
|
||||||
|
生成 QA 对
|
||||||
|
model_name: 可调用的模型名称,暂时只实现了 qwen
|
||||||
|
interval: 存储间隔,即每隔多少条存一次文件,过密的间隔会增大 IO 开销
|
||||||
|
"""
|
||||||
|
def generate_qa(
|
||||||
|
model_name: str = 'qwen',
|
||||||
|
interval: int = 1,
|
||||||
|
):
|
||||||
|
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
|
||||||
|
if model_name == 'qwen':
|
||||||
|
model_caller = call_qwen_single_turn
|
||||||
|
else:
|
||||||
|
logger.warning('This model is currently not supported and will call the default model - qwen.')
|
||||||
|
model_caller = call_qwen_single_turn
|
||||||
|
model_name = 'qwen'
|
||||||
|
|
||||||
|
logger.info(f'The called model is: {model_name}.')
|
||||||
|
logger.info(f'The storage interval is: {interval}.')
|
||||||
|
|
||||||
|
file_list = get_file_list()
|
||||||
|
storage_counter = 0
|
||||||
|
storage_list = []
|
||||||
|
for file_name in file_list:
|
||||||
|
contents = get_txt_content(file_name)
|
||||||
|
storage_list = []
|
||||||
|
storage_jsonl_path = os.path.join(result_dir, f'{current_time}-{file_name}-{model_name}.jsonl')
|
||||||
|
logger.info(f'The generated QA will be stored in {storage_jsonl_path}.')
|
||||||
|
|
||||||
|
for content in tqdm(contents):
|
||||||
|
response = model_caller(content)
|
||||||
|
captured_qa = capture_qa(response)
|
||||||
|
if captured_qa is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
storage_list.extend(captured_qa)
|
||||||
|
storage_counter += 1
|
||||||
|
if storage_counter % interval == 0:
|
||||||
|
storage_counter = 0
|
||||||
|
with open(storage_jsonl_path, 'a', encoding='utf-8') as f:
|
||||||
|
for item in storage_list:
|
||||||
|
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||||
|
storage_list = []
|
||||||
|
|
||||||
|
# 如果有剩余,存入
|
||||||
|
if storage_list:
|
||||||
|
with open(storage_jsonl_path, 'a', encoding='utf-8') as f:
|
||||||
|
for item in storage_list:
|
||||||
|
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||||
|
storage_list = []
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
generate_qa()
|
0
scripts/qa_generation/model/__init__.py
Normal file
0
scripts/qa_generation/model/__init__.py
Normal file
0
scripts/qa_generation/model/gemini.py
Normal file
0
scripts/qa_generation/model/gemini.py
Normal file
0
scripts/qa_generation/model/glm.py
Normal file
0
scripts/qa_generation/model/glm.py
Normal file
0
scripts/qa_generation/model/gpt.py
Normal file
0
scripts/qa_generation/model/gpt.py
Normal file
41
scripts/qa_generation/model/qwen.py
Normal file
41
scripts/qa_generation/model/qwen.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import dashscope
|
||||||
|
from http import HTTPStatus
|
||||||
|
from dashscope import Generation
|
||||||
|
from dashscope.api_entities.dashscope_response import Role
|
||||||
|
|
||||||
|
from config.config import DASHSCOPE_API_KEY
|
||||||
|
from util.logger import get_logger
|
||||||
|
from util.prompt_loader import load_system_prompt
|
||||||
|
|
||||||
|
|
||||||
|
dashscope.api_key = DASHSCOPE_API_KEY
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def call_qwen_single_turn(query: str) -> str:
    """Send one single-turn chat request to Qwen and return the reply text.

    The system prompt is loaded fresh for every call; on any non-OK HTTP
    status the error is logged and an empty string is returned.
    """
    conversation = [
        {'role': Role.SYSTEM, 'content': load_system_prompt()},
        {'role': Role.USER, 'content': query},
    ]
    response = Generation.call(
        model='qwen-max-1201',
        messages=conversation,
        result_format='message',
        stream=False,
        incremental_output=False
    )
    if response.status_code != HTTPStatus.OK:
        logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        return ""
    return response.output.choices[0]['message']['content']
|
3
scripts/qa_generation/requirements.txt
Normal file
3
scripts/qa_generation/requirements.txt
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
dashscope
|
||||||
|
loguru
|
||||||
|
tqdm
|
24
scripts/qa_generation/system_prompt.md
Normal file
24
scripts/qa_generation/system_prompt.md
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
你是一名 QA 对生成机器人,你会根据我提供的【心理学书本内容】自动生成合适的 QA 对,要求如下:
|
||||||
|
|
||||||
|
- 对于我给的文本内容,你需要生成五条这样的 QA 对
|
||||||
|
- QA 对内容不能重复,答案不能过长
|
||||||
|
- 用简体中文回答
|
||||||
|
- 生成的 QA 对需要用 markdown 格式的 json 代码块包裹起来
|
||||||
|
|
||||||
|
以下是参考格式:
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"question": "...",
|
||||||
|
"answer": "..."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"question": "...",
|
||||||
|
"answer": "..."
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
以下是给定的文本内容:
|
0
scripts/qa_generation/util/__init__.py
Normal file
0
scripts/qa_generation/util/__init__.py
Normal file
75
scripts/qa_generation/util/data_loader.py
Normal file
75
scripts/qa_generation/util/data_loader.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from config.config import data_dir
|
||||||
|
from util.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
"""
|
||||||
|
递归获取 data_dir 下的所有 .txt 文件列表
|
||||||
|
"""
|
||||||
|
def get_file_list() -> List[str]:
|
||||||
|
txt_files = []
|
||||||
|
txt_exist_flag = False
|
||||||
|
for root, dirs, files in os.walk(data_dir):
|
||||||
|
for file in files:
|
||||||
|
if file.endswith('.txt'):
|
||||||
|
txt_exist_flag = True
|
||||||
|
txt_files.append(os.path.join(root, file))
|
||||||
|
|
||||||
|
if not txt_exist_flag:
|
||||||
|
logger.warning(f'No txt text found in {data_dir}, please check!')
|
||||||
|
return txt_files
|
||||||
|
|
||||||
|
"""
|
||||||
|
获取 txt 文本的所有内容,按句子返回 List
|
||||||
|
file_path: txt 文本路径
|
||||||
|
window_size: 滑窗大小,单位为句子数
|
||||||
|
overlap_size: 重叠大小,单位为句子数
|
||||||
|
"""
|
||||||
|
def get_txt_content(
|
||||||
|
file_path: str,
|
||||||
|
window_size: int = 6,
|
||||||
|
overlap_size: int = 2
|
||||||
|
) -> List[str]:
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
content = f.read().strip()
|
||||||
|
|
||||||
|
# 简单实现:按句号、感叹号、问号分割,并去除句内空白符
|
||||||
|
sentences = re.split(r'(?<=[。!?])\s+', content)
|
||||||
|
sentences = [s.replace(' ', '').replace('\t', '') for s in sentences]
|
||||||
|
|
||||||
|
# 滑窗
|
||||||
|
res = []
|
||||||
|
sentences_amount = len(sentences)
|
||||||
|
start_index, end_index = 0, sentences_amount - window_size
|
||||||
|
## check length
|
||||||
|
if window_size < overlap_size:
|
||||||
|
logger.error("window_size must be greater than or equal to overlap_size")
|
||||||
|
return None
|
||||||
|
if window_size >= sentences_amount:
|
||||||
|
logger.warning("window_size exceeds the amount of sentences, and the complete text content will be returned")
|
||||||
|
return ['\n'.join(sentences)]
|
||||||
|
|
||||||
|
for i in range(start_index, end_index + 1, overlap_size):
|
||||||
|
res.append('\n'.join(sentences[i : i + window_size]))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
提取返回的 QA 对
|
||||||
|
"""
|
||||||
|
def capture_qa(content: str) -> List[Dict]:
|
||||||
|
# 只捕获第一个 json 块
|
||||||
|
match = re.search(r'```json(.*?)```', content, re.DOTALL)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
block = match.group(1)
|
||||||
|
parsed_data = json.loads(block)
|
||||||
|
return parsed_data
|
||||||
|
else:
|
||||||
|
logger.warning("No JSON block found.")
|
||||||
|
return None
|
14
scripts/qa_generation/util/logger.py
Normal file
14
scripts/qa_generation/util/logger.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from loguru import logger

from config.config import log_file_path


# Attach the file sink exactly once at import time.
# BUG FIX: the original called both logger.add(...) and
# logger.configure(handlers=[...]); loguru's configure() removes every
# previously added handler — including the add() above AND the default
# stderr sink — so the add() was dead code and console output was lost.
# A single add() keeps both the rotating file sink and stderr.
logger.add(log_file_path, rotation="500 MB", format="{time} {level} {message}")


def get_logger():
    """Return the shared, pre-configured loguru logger."""
    return logger
|
7
scripts/qa_generation/util/prompt_loader.py
Normal file
7
scripts/qa_generation/util/prompt_loader.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from config.config import system_prompt_file_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_system_prompt() -> str:
    """Read the system prompt from system_prompt_file_path and return it."""
    with open(system_prompt_file_path, 'r', encoding='utf-8') as f:
        return f.read()
|
Loading…
Reference in New Issue
Block a user