Merge pull request #86 from SmartFlowAI/dev

merge clean_qa
xzw 2024-03-16 23:10:14 +08:00 committed by GitHub
commit 60fe587f06
8 changed files with 208 additions and 8 deletions

.gitignore vendored (2 changes)
View File

@@ -3,6 +3,8 @@ ESConv.json
tmp/
zhipuai/
data/
pdf/
.idea/
# Byte-compiled / optimized / DLL files
__pycache__/

View File

@@ -0,0 +1,11 @@
# Clean QA pairs
Call qwen to judge whether each QA pair belongs to the domain of psychology, and remove the QA pairs that do not.
## Step 1
1. Prepare the QA-pair data that needs to be cleaned.
2. Put the data into the data folder at the same level as model.
3. Update judge_dir in config/config.py to match the folder name. I did not rename the folder, so my judge_dir is judge_dir = os.path.join(data_dir, '数据整合').
## Step 2
1. Run QA_clean.py.
2. The cleaned QA pairs are stored in jsonl format under data/cleaned (a quick sanity-check sketch follows below).
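
The exact line schema is not fixed by this commit; here is a minimal sanity-check sketch, assuming each line of every input .jsonl is one self-contained JSON object (the field names in the comment are hypothetical):

# Sketch, not part of the commit: count raw vs. kept QA pairs.
# Assumes judge_dir/clean_dir as defined in config/config.py, with each
# input line looking like {"input": "...", "output": "..."}.
import os

judge_dir = os.path.join('data', '数据整合')
clean_dir = os.path.join('data', 'cleaned')

def count_pairs(directory: str) -> int:
    total = 0
    for dirpath, _, filenames in os.walk(directory):
        for name in filenames:
            if name.endswith('.jsonl'):
                with open(os.path.join(dirpath, name), encoding='utf-8') as f:
                    total += sum(1 for line in f if line.strip())
    return total

print('raw pairs:', count_pairs(judge_dir))
print('kept pairs:', count_pairs(clean_dir))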

QA_clean.py
View File

@@ -0,0 +1,111 @@
import os
import json
import concurrent.futures

import numpy as np
from tqdm import tqdm

from config.config import result_dir, clean_dir, storage_interval, multi_process_num
from model.qwen import call_qwen_Psychology_QA_Pairs
from util.logger import get_logger
from util.data_loader import get_jsonl_file_paths, get_QA_pairs, merge_sub_qa_generation, save_to_file

logger = get_logger()


def single_thread_generate(thread_num, interval, model_caller, storage_jsonl_path, contents):
    storage_counter = 0
    judge_list = []
    for content in tqdm(contents):
        try:
            # model_caller sends one serialized QA pair to the judge model and
            # returns its verdict: '1' keeps the pair, anything else drops it.
            response = model_caller(content)
            if response == '1':
                judge_list.append(json.loads(content))
                storage_counter += 1
            else:
                continue
            # Flush judge_list to storage_jsonl_path every `interval` kept pairs.
            if storage_counter % interval == 0:
                save_to_file(storage_jsonl_path, judge_list)
                storage_counter = 0
                judge_list = []
        except Exception as exc:
            logger.error("QA generation error : %s" % (exc))
    # Flush whatever is left at the end.
    if judge_list:
        save_to_file(storage_jsonl_path, judge_list)
        judge_list = []


def clean_qa(
    model_name: str = 'qwen',
    interval: int = 10,
):
    """
    Clean QA pairs.
    model_name: name of the model to call; only qwen is implemented for now.
    interval: storage interval, i.e. how many kept pairs to buffer before each
        write; an overly small interval increases IO overhead.
    """
    if model_name == 'qwen':
        model_caller = call_qwen_Psychology_QA_Pairs
    else:
        logger.warning('This model is currently not supported and will call the default model - qwen.')
        model_caller = call_qwen_Psychology_QA_Pairs
        model_name = 'qwen'

    logger.info(f'The called model is: {model_name}.')
    logger.info(f'The storage interval is: {interval}.')

    file_lists = get_jsonl_file_paths()  # paths of all .jsonl files under the 数据整合 folder
    for file_path in file_lists:
        # All QA pairs of one jsonl file.
        contents = get_QA_pairs(file_path)

        file_name = os.path.basename(file_path)
        logger.info(f'Processing {file_name}.')
        # All workers append to the same output file via save_to_file.
        storage_jsonl_path = os.path.join(clean_dir, file_name)
        logger.info(f'The cleaned QA will be stored in {storage_jsonl_path}.')

        # Split the pairs into one chunk per worker thread.
        contents_array = np.array(contents)
        chunks = np.array_split(contents_array, multi_process_num)

        # Build the parameter list for the concurrent workers.
        parameters_list = list()
        for thread_num, chunk in enumerate(chunks):
            parameters_list.append(
                [thread_num, interval, model_caller, storage_jsonl_path, list(chunk)]
            )

        with concurrent.futures.ThreadPoolExecutor(max_workers=multi_process_num) as executor:
            # Submit single_thread_generate once per chunk.
            futures = [executor.submit(single_thread_generate, *parameters) for parameters in parameters_list]
            for future in concurrent.futures.as_completed(futures):
                try:
                    future.result()
                except Exception as exc:
                    logger.error("Thread generated an exception: %s" % (exc))

        merge_sub_qa_generation(result_dir, storage_jsonl_path)


if __name__ == '__main__':
    # Create the output folder for cleaned data.
    os.makedirs('./data/cleaned', exist_ok=True)
    clean_qa(interval=storage_interval)
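
For context, clean_qa splits each file's QA pairs into multi_process_num chunks with np.array_split, which tolerates list lengths that do not divide evenly. A standalone sketch of that behavior (the fake pairs are made up):

# Sketch, not part of the commit: how np.array_split distributes work.
import numpy as np

pairs = [f'{{"q": "q{i}"}}' for i in range(10)]  # 10 fake serialized pairs
chunks = np.array_split(np.array(pairs), 3)      # multi_process_num = 3

for thread_num, chunk in enumerate(chunks):
    print(thread_num, len(chunk))                # prints sizes 4, 3, 3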

choose_prompt.md
View File

@@ -0,0 +1,8 @@
You are an experienced psychological counselor who is familiar with psychology. Based on the QA pair I provide, judge whether this QA pair belongs to the domain of psychology.
The criteria are as follows:
- If the QA pair belongs to the domain of psychology, return 1.
- If the QA pair does not belong to the domain of psychology, return 0.
Here is the given psychology QA pair:
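
For illustration only (not part of the commit), the message pair that model/qwen.py builds from this prompt looks like the following; the QA text is made up, and the script is assumed to run from the directory containing choose_prompt.md:

# Hypothetical example of the messages sent to the judge model.
with open('choose_prompt.md', encoding='utf-8') as f:
    wash_prompt = f.read()

messages = [
    {'role': 'system', 'content': wash_prompt},
    {'role': 'user', 'content': '{"q": "How can I cope with exam anxiety?", "a": "..."}'},
]
# Expected reply: the single character '1' (psychology) or '0' (not).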

config/config.py
View File

@@ -10,7 +10,9 @@ base_dir = os.path.dirname(cur_dir) # ba
model_dir = os.path.join(base_dir, 'model') # model
# data
data_dir = os.path.join(base_dir, 'data') # data
data_dir = os.path.join(base_dir, 'data')
clean_dir = os.path.join(data_dir, 'cleaned')
judge_dir = os.path.join(data_dir, '数据整合')
result_dir = os.path.join(data_dir, 'generated') # result
# log
@@ -18,7 +20,9 @@ log_dir = os.path.join(base_dir, 'log') # lo
log_file_path = os.path.join(log_dir, 'log.log') # file
# system prompt
# prompt contents
system_prompt_file_path = os.path.join(base_dir, 'system_prompt_v2.md') # system prompt
wash_prompt_file_path = os.path.join(base_dir, 'choose_prompt.md')
"""
@@ -28,7 +32,6 @@ system_prompt_file_path = os.path.join(base_dir, 'system_prompt_v2.md') # sy
DASHSCOPE_API_KEY = ''
"""
Control parameters
"""
@@ -36,3 +39,4 @@ storage_interval = 10
window_size = 8
overlap_size = 2
multi_process_num = 3
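
Taken together, these paths imply the directory layout the cleaning step expects. A minimal sketch (assuming the repository root as base_dir) that materializes the input and output folders:

# Sketch, not part of the commit: create the folders QA_clean.py relies on.
import os

base_dir = os.path.abspath('.')                 # assumption: repo root
data_dir = os.path.join(base_dir, 'data')
judge_dir = os.path.join(data_dir, '数据整合')   # raw QA pairs to be judged
clean_dir = os.path.join(data_dir, 'cleaned')   # kept QA pairs, jsonl

for d in (judge_dir, clean_dir):
    os.makedirs(d, exist_ok=True)
    print('ready:', d)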

model/qwen.py
View File

@@ -5,7 +5,7 @@ from dashscope.api_entities.dashscope_response import Role
from config.config import DASHSCOPE_API_KEY
from util.logger import get_logger
from util.prompt_loader import load_system_prompt
from util.prompt_loader import load_system_prompt, load_wash_prompt
dashscope.api_key = DASHSCOPE_API_KEY
@@ -39,3 +39,31 @@ def call_qwen_single_turn(query: str) -> str:
            response.code, response.message
        ))
        return ""


def call_qwen_Psychology_QA_Pairs(query: str) -> str:
    # Judge whether one QA pair is psychology-related: the wash prompt
    # instructs the model to answer '1' (keep) or '0' (drop).
    messages = [
        {
            'role': Role.SYSTEM,
            'content': load_wash_prompt()
        },
        {
            'role': Role.USER,
            'content': query
        }
    ]
    response = Generation.call(
        model='qwen-max-1201',
        messages=messages,
        result_format='message',
        stream=False,
        incremental_output=False
    )
    if response.status_code == HTTPStatus.OK:
        return response.output.choices[0]['message']['content']
    else:
        logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        return ""

util/data_loader.py
View File

@@ -4,11 +4,39 @@ import json
import glob
from typing import List, Dict
from config.config import data_dir
from config.config import data_dir, judge_dir
from util.logger import get_logger
logger = get_logger()
"""
递归获取 数据整合 下的所有 .jsonl 文件列表
"""
def get_jsonl_file_paths() -> List[str]:
json_file_paths = []
# 遍历根目录及其所有子目录
for dirpath, dirnames, filenames in os.walk(judge_dir):
# 对每个文件进行检查
for filename in filenames:
# 使用正则表达式匹配以.jsonl结尾的文件名
if re.search(r'\.jsonl$', filename):
# 构建完整的文件路径并添加到列表中
json_file_path = os.path.join(dirpath, filename)
json_file_paths.append(json_file_path)
return json_file_paths
def get_QA_pairs(json_path):
with open(json_path, 'r', encoding='utf-8') as f:
content = f.read().strip()
# 按照换行符分割字符串
QA_Pairs = content.split('\n')
return QA_Pairs
"""
递归获取 data_dir 下的所有 .txt 文件列表
"""
@@ -47,7 +75,7 @@ res = []
    res = []
    sentences_amount = len(sentences)
    start_index, end_index = 0, sentences_amount - window_size
    ## check length
    # check length
    if window_size < overlap_size:
        logger.error("window_size must be greater than or equal to overlap_size")
        return None
@@ -56,7 +84,7 @@
        return ['\n'.join(sentences)]
    for i in range(start_index, end_index + 1, overlap_size):
        res.append('\n'.join(sentences[i : i + window_size]))
        res.append('\n'.join(sentences[i: i + window_size]))
    return res
@@ -80,6 +108,7 @@ def capture_qa(content: str) -> List[Dict]:
        logger.warning("No JSON block found.")
        return None

"""
Write storage_list to storage_jsonl_path (one JSON object per line)
"""
@@ -88,6 +117,7 @@ def save_to_file(storage_jsonl_path, storage_list):
        for item in storage_list:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

"""
Merge the files produced by the concurrent workers into a single file
"""
@@ -102,5 +132,4 @@ def merge_sub_qa_generation(directory, storage_jsonl_path):
            for line in f:
                file_contents.append(json.loads(line))
            os.remove(file_path)
            save_to_file(storage_jsonl_path, file_contents)
    save_to_file(storage_jsonl_path, file_contents)
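
A small end-to-end sketch of the two new helpers (not part of the commit; assumes judge_dir resolves to data/数据整合 and the script runs from the repository root):

# Sketch: exercise get_jsonl_file_paths / get_QA_pairs on a throwaway tree.
import os

os.makedirs('data/数据整合/sub', exist_ok=True)
with open('data/数据整合/sub/sample.jsonl', 'w', encoding='utf-8') as f:
    f.write('{"q": "...", "a": "..."}\n{"q": "...", "a": "..."}\n')

from util.data_loader import get_jsonl_file_paths, get_QA_pairs

for path in get_jsonl_file_paths():        # recursive .jsonl discovery
    print(path, len(get_QA_pairs(path)))   # one string per non-empty line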

util/prompt_loader.py
View File

@@ -1,7 +1,14 @@
from config.config import system_prompt_file_path
from config.config import wash_prompt_file_path


def load_system_prompt() -> str:
    with open(system_prompt_file_path, 'r', encoding='utf-8') as f:
        system_prompt = f.read()
    return system_prompt


def load_wash_prompt() -> str:
    with open(wash_prompt_file_path, 'r', encoding='utf-8') as f:
        wash_prompt = f.read()
    return wash_prompt