commit
60fe587f06
2
.gitignore
vendored
2
.gitignore
vendored
@ -3,6 +3,8 @@ ESConv.json
|
||||
tmp/
|
||||
zhipuai/
|
||||
data/
|
||||
pdf/
|
||||
.idea/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
11
scripts/qa_generation/Clean_QA.md
Normal file
11
scripts/qa_generation/Clean_QA.md
Normal file
@ -0,0 +1,11 @@
|
||||
# 清洗 QA 对
|
||||
调用qwen去判断当前QA对是否属于心理学范畴,去除非心理学范畴的 QA 对
|
||||
|
||||
## Step 1
|
||||
1. 准备好需要清洗的 QA 对数据
|
||||
2. 将该数据放进 model 同级 data 文件夹下
|
||||
3. 根据文件夹名去修改 config/config.py 中的 judge_dir。我个人没有对文件名进行更改,所以我的judge_dir是 judge_dir = os.path.join(data_dir, '数据整合')
|
||||
|
||||
## Step 2
|
||||
1. 运行QA_clean.py即可
|
||||
2. 清洗完的 QA 对会以 jsonl 的格式存在 data/cleaned 下
|
111
scripts/qa_generation/QA_clean.py
Normal file
111
scripts/qa_generation/QA_clean.py
Normal file
@ -0,0 +1,111 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
import concurrent.futures
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
from config.config import result_dir, clean_dir, storage_interval, window_size, overlap_size, multi_process_num
|
||||
from model.qwen import call_qwen_single_turn, call_qwen_Psychology_QA_Pairs
|
||||
from util.logger import get_logger
|
||||
from util.data_loader import get_jsonl_file_paths, get_file_list, get_QA_pairs, get_txt_content, capture_qa, merge_sub_qa_generation, save_to_file
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def single_thread_generate(thread_num, interval, model_caller, storage_jsonl_path, contents):
|
||||
|
||||
storage_counter = 0
|
||||
judge_list = []
|
||||
for content in tqdm(contents):
|
||||
# print('content: ', content)
|
||||
try:
|
||||
# model_caller 函数的作用是调用某个预训练的问答生成模型,传递输入内容 content 给模型,然后获取模型的输出 response
|
||||
response = model_caller(content)
|
||||
# print('response: ', response)
|
||||
|
||||
if response == '1':
|
||||
content = json.loads(content)
|
||||
judge_list.append(content)
|
||||
storage_counter += 1
|
||||
else:
|
||||
continue
|
||||
|
||||
# 在达到指定的 interval 后,将 storage_list 中的内容保存到指定的文件 storage_jsonl_path 中
|
||||
if storage_counter % interval == 0:
|
||||
save_to_file(storage_jsonl_path, judge_list)
|
||||
storage_counter = 0
|
||||
judge_list = []
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("QA generation error : %s" % (exc))
|
||||
|
||||
# 最后,如果 storage_list 中还有剩余内容,也会将其保存到文件中。
|
||||
if judge_list:
|
||||
save_to_file(storage_jsonl_path, judge_list)
|
||||
judge_list = []
|
||||
|
||||
|
||||
"""
|
||||
生成 QA 对
|
||||
model_name: 可调用的模型名称,暂时只实现了 qwen
|
||||
interval: 存储间隔,即每隔多少条存一次文件,过密的间隔会增大 IO 开销
|
||||
"""
|
||||
def clean_qa(
|
||||
model_name: str = 'qwen',
|
||||
interval: int = 10,
|
||||
):
|
||||
# current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
if model_name == 'qwen':
|
||||
model_caller = call_qwen_Psychology_QA_Pairs
|
||||
else:
|
||||
logger.warning('This model is currently not supported and will call the default model - qwen.')
|
||||
model_caller = call_qwen_Psychology_QA_Pairs
|
||||
model_name = 'qwen'
|
||||
|
||||
logger.info(f'The called model is: {model_name}.')
|
||||
logger.info(f'The storage interval is: {interval}.')
|
||||
|
||||
file_lists = get_jsonl_file_paths() # 数据整合文件夹下所有.jsonl文件的地址
|
||||
|
||||
for file_path in file_lists:
|
||||
# 一个jsonl文件的所有QA Pairs
|
||||
contents = get_QA_pairs(file_path)
|
||||
# print(contents)
|
||||
|
||||
file_name = os.path.basename(file_path)
|
||||
print(file_name)
|
||||
storage_jsonl_path = os.path.join(
|
||||
clean_dir, f'{file_name}')
|
||||
|
||||
logger.info(f'The generated QA will be stored in {storage_jsonl_path}.')
|
||||
|
||||
contents_array = np.array(contents)
|
||||
chunks = np.array_split(contents_array, multi_process_num)
|
||||
|
||||
# 构建并发参数 list
|
||||
parameters_list = list()
|
||||
for thread_num, chunk in enumerate(chunks):
|
||||
parameters_list.append(
|
||||
[thread_num, interval, model_caller, storage_jsonl_path, list(chunk)]
|
||||
)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=multi_process_num) as executor:
|
||||
# 循环调用 single_thread_generate 函数,每次赋予参数 parameters
|
||||
futures = [executor.submit(single_thread_generate, *parameters) for parameters in parameters_list]
|
||||
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as exc:
|
||||
logger.error("Thread generated an exception: %s" % (exc))
|
||||
|
||||
merge_sub_qa_generation(result_dir, storage_jsonl_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 创建washed文件夹
|
||||
os.makedirs('./data/cleaned', exist_ok=True)
|
||||
clean_qa(interval=storage_interval)
|
8
scripts/qa_generation/choose_prompt.md
Normal file
8
scripts/qa_generation/choose_prompt.md
Normal file
@ -0,0 +1,8 @@
|
||||
你是一名经验丰富的心理咨询师,熟悉心理学相关知识。根据我提供的 QA 对,来判断这个 QA 对是否属于心理学范畴。
|
||||
|
||||
标准如下:
|
||||
- 若当前 QA 对属于心理学范畴,则返回1
|
||||
- 若当前 QA 对不属于心理学范畴,则返回0
|
||||
|
||||
|
||||
以下是给定的心理学 QA 对内容:
|
@ -10,7 +10,9 @@ base_dir = os.path.dirname(cur_dir) # ba
|
||||
model_dir = os.path.join(base_dir, 'model') # model
|
||||
|
||||
# data
|
||||
data_dir = os.path.join(base_dir, 'data') # data
|
||||
data_dir = os.path.join(base_dir, 'data')
|
||||
clean_dir = os.path.join(data_dir, 'cleaned')
|
||||
judge_dir = os.path.join(data_dir, '数据整合')
|
||||
result_dir = os.path.join(data_dir, 'generated') # result
|
||||
|
||||
# log
|
||||
@ -18,7 +20,9 @@ log_dir = os.path.join(base_dir, 'log') # lo
|
||||
log_file_path = os.path.join(log_dir, 'log.log') # file
|
||||
|
||||
# system prompt
|
||||
# Prompt内容
|
||||
system_prompt_file_path = os.path.join(base_dir, 'system_prompt_v2.md') # system prompt
|
||||
wash_prompt_file_path = os.path.join(base_dir, 'choose_prompt.md')
|
||||
|
||||
|
||||
"""
|
||||
@ -28,7 +32,6 @@ system_prompt_file_path = os.path.join(base_dir, 'system_prompt_v2.md') # sy
|
||||
DASHSCOPE_API_KEY = ''
|
||||
|
||||
|
||||
|
||||
"""
|
||||
控制参数
|
||||
"""
|
||||
@ -36,3 +39,4 @@ storage_interval = 10
|
||||
window_size = 8
|
||||
overlap_size = 2
|
||||
multi_process_num = 3
|
||||
|
||||
|
@ -5,7 +5,7 @@ from dashscope.api_entities.dashscope_response import Role
|
||||
|
||||
from config.config import DASHSCOPE_API_KEY
|
||||
from util.logger import get_logger
|
||||
from util.prompt_loader import load_system_prompt
|
||||
from util.prompt_loader import load_system_prompt, load_wash_prompt
|
||||
|
||||
|
||||
dashscope.api_key = DASHSCOPE_API_KEY
|
||||
@ -39,3 +39,31 @@ def call_qwen_single_turn(query: str) -> str:
|
||||
response.code, response.message
|
||||
))
|
||||
return ""
|
||||
|
||||
|
||||
def call_qwen_Psychology_QA_Pairs(query: str) -> str:
|
||||
messages = [
|
||||
{
|
||||
'role': Role.SYSTEM,
|
||||
'content': load_wash_prompt()
|
||||
},
|
||||
{
|
||||
'role': Role.USER,
|
||||
'content': query
|
||||
}
|
||||
]
|
||||
response = Generation.call(
|
||||
model='qwen-max-1201',
|
||||
messages=messages,
|
||||
result_format='message',
|
||||
stream=False,
|
||||
incremental_output=False
|
||||
)
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
return response.output.choices[0]['message']['content']
|
||||
else:
|
||||
logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
|
||||
response.request_id, response.status_code,
|
||||
response.code, response.message
|
||||
))
|
||||
return ""
|
||||
|
@ -4,11 +4,39 @@ import json
|
||||
import glob
|
||||
from typing import List, Dict
|
||||
|
||||
from config.config import data_dir
|
||||
from config.config import data_dir, judge_dir
|
||||
from util.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
"""
|
||||
递归获取 数据整合 下的所有 .jsonl 文件列表
|
||||
"""
|
||||
def get_jsonl_file_paths() -> List[str]:
|
||||
json_file_paths = []
|
||||
|
||||
# 遍历根目录及其所有子目录
|
||||
for dirpath, dirnames, filenames in os.walk(judge_dir):
|
||||
# 对每个文件进行检查
|
||||
for filename in filenames:
|
||||
# 使用正则表达式匹配以.jsonl结尾的文件名
|
||||
if re.search(r'\.jsonl$', filename):
|
||||
# 构建完整的文件路径并添加到列表中
|
||||
json_file_path = os.path.join(dirpath, filename)
|
||||
json_file_paths.append(json_file_path)
|
||||
|
||||
return json_file_paths
|
||||
|
||||
def get_QA_pairs(json_path):
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 按照换行符分割字符串
|
||||
QA_Pairs = content.split('\n')
|
||||
|
||||
return QA_Pairs
|
||||
|
||||
"""
|
||||
递归获取 data_dir 下的所有 .txt 文件列表
|
||||
"""
|
||||
@ -47,7 +75,7 @@ def get_txt_content(
|
||||
res = []
|
||||
sentences_amount = len(sentences)
|
||||
start_index, end_index = 0, sentences_amount - window_size
|
||||
## check length
|
||||
# check length
|
||||
if window_size < overlap_size:
|
||||
logger.error("window_size must be greater than or equal to overlap_size")
|
||||
return None
|
||||
@ -56,7 +84,7 @@ def get_txt_content(
|
||||
return ['\n'.join(sentences)]
|
||||
|
||||
for i in range(start_index, end_index + 1, overlap_size):
|
||||
res.append('\n'.join(sentences[i : i + window_size]))
|
||||
res.append('\n'.join(sentences[i: i + window_size]))
|
||||
return res
|
||||
|
||||
|
||||
@ -80,6 +108,7 @@ def capture_qa(content: str) -> List[Dict]:
|
||||
logger.warning("No JSON block found.")
|
||||
return None
|
||||
|
||||
|
||||
"""
|
||||
将 storage_list 存入到 storage_jsonl_path
|
||||
"""
|
||||
@ -88,6 +117,7 @@ def save_to_file(storage_jsonl_path, storage_list):
|
||||
for item in storage_list:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
|
||||
"""
|
||||
将并发产生的文件合并成为一个文件
|
||||
"""
|
||||
@ -102,5 +132,4 @@ def merge_sub_qa_generation(directory, storage_jsonl_path):
|
||||
for line in f:
|
||||
file_contents.append(json.loads(line))
|
||||
os.remove(file_path)
|
||||
save_to_file(storage_jsonl_path, file_contents)
|
||||
|
||||
save_to_file(storage_jsonl_path, file_contents)
|
@ -1,7 +1,14 @@
|
||||
from config.config import system_prompt_file_path
|
||||
from config.config import wash_prompt_file_path
|
||||
|
||||
|
||||
def load_system_prompt() -> str:
|
||||
with open(system_prompt_file_path, 'r', encoding='utf-8') as f:
|
||||
system_prompt = f.read()
|
||||
return system_prompt
|
||||
|
||||
|
||||
def load_wash_prompt() -> str:
|
||||
with open(wash_prompt_file_path, 'r', encoding='utf-8') as f:
|
||||
wash_prompt = f.read()
|
||||
return wash_prompt
|
||||
|
Loading…
Reference in New Issue
Block a user