commit
60fe587f06
2
.gitignore
vendored
2
.gitignore
vendored
@ -3,6 +3,8 @@ ESConv.json
|
|||||||
tmp/
|
tmp/
|
||||||
zhipuai/
|
zhipuai/
|
||||||
data/
|
data/
|
||||||
|
pdf/
|
||||||
|
.idea/
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
11
scripts/qa_generation/Clean_QA.md
Normal file
11
scripts/qa_generation/Clean_QA.md
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# 清洗 QA 对
|
||||||
|
调用 qwen 判断每个 QA 对是否属于心理学范畴,并去除不属于心理学范畴的 QA 对
|
||||||
|
|
||||||
|
## Step 1
|
||||||
|
1. 准备好需要清洗的 QA 对数据
|
||||||
|
2. 将该数据放进 model 同级 data 文件夹下
|
||||||
|
3. 根据该文件夹的名称修改 config/config.py 中的 judge_dir。例如,我没有更改文件夹名,所以我的配置是 judge_dir = os.path.join(data_dir, '数据整合')
|
||||||
|
|
||||||
|
## Step 2
|
||||||
|
1. 运行QA_clean.py即可
|
||||||
|
2. 清洗完的 QA 对会以 jsonl 格式存放在 data/cleaned 目录下
|
111
scripts/qa_generation/QA_clean.py
Normal file
111
scripts/qa_generation/QA_clean.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from tqdm import tqdm
|
||||||
|
import concurrent.futures
|
||||||
|
from datetime import datetime
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from config.config import result_dir, clean_dir, storage_interval, window_size, overlap_size, multi_process_num
|
||||||
|
from model.qwen import call_qwen_single_turn, call_qwen_Psychology_QA_Pairs
|
||||||
|
from util.logger import get_logger
|
||||||
|
from util.data_loader import get_jsonl_file_paths, get_file_list, get_QA_pairs, get_txt_content, capture_qa, merge_sub_qa_generation, save_to_file
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def single_thread_generate(thread_num, interval, model_caller, storage_jsonl_path, contents):
    """Judge one chunk of QA pairs with the model and persist the accepted ones.

    Every non-blank item of ``contents`` is passed to ``model_caller``. When
    the reply, ignoring surrounding whitespace, is ``'1'`` the item is parsed
    with ``json.loads`` and buffered. The buffer is handed to ``save_to_file``
    once it holds ``interval`` items, and once more after the loop for any
    remainder.

    Args:
        thread_num: Index of the worker. Not used inside this function.
        interval: Number of accepted items to buffer before each save.
        model_caller: Callable taking the raw QA-pair string and returning
            the model's verdict as a string.
        storage_jsonl_path: Path passed to ``save_to_file`` for every flush.
        contents: Iterable of raw JSON strings, one QA pair each.

    Errors raised while handling a single item are logged and the loop moves
    on to the next item.
    """
    accepted = []
    for content in tqdm(contents):
        try:
            # A blank line cannot be a QA pair; skip it without spending a model call.
            if not str(content).strip():
                continue

            verdict = model_caller(content)

            # The model may pad its answer with whitespace or a newline, so
            # compare the trimmed reply instead of the raw string.
            if not (isinstance(verdict, str) and verdict.strip() == '1'):
                continue

            accepted.append(json.loads(content))

            # Flush once the buffer reaches the configured size. Comparing
            # lengths (instead of a modulo) also avoids a ZeroDivisionError
            # when interval is 0.
            if len(accepted) >= interval:
                save_to_file(storage_jsonl_path, accepted)
                accepted = []

        except Exception as exc:
            logger.error("QA generation error : %s", exc)

    # Persist whatever is still buffered after the last item.
    if accepted:
        save_to_file(storage_jsonl_path, accepted)
|
||||||
|
|
||||||
|
|
||||||
|
def clean_qa(
    model_name: str = 'qwen',
    interval: int = 10,
):
    """Clean QA pairs by keeping only those the model judges as psychology-related.

    For every ``.jsonl`` file returned by ``get_jsonl_file_paths`` the QA
    pairs are split into ``multi_process_num`` chunks, each chunk is judged
    by ``single_thread_generate`` on a thread pool, and the accepted pairs
    are written to a file of the same base name under ``clean_dir``.

    Args:
        model_name: Name of the model to call. Only ``'qwen'`` is
            implemented; any other value logs a warning and falls back to it.
        interval: Storage interval, i.e. how many accepted pairs are buffered
            between file writes. A very small interval increases I/O overhead.
    """
    if model_name != 'qwen':
        logger.warning('This model is currently not supported and will call the default model - qwen.')
        model_name = 'qwen'
    model_caller = call_qwen_Psychology_QA_Pairs

    logger.info(f'The called model is: {model_name}.')
    logger.info(f'The storage interval is: {interval}.')

    # The workers write into clean_dir, so make sure it exists regardless of
    # how (or from which working directory) this function is invoked.
    os.makedirs(clean_dir, exist_ok=True)

    # Paths of all .jsonl files found by get_jsonl_file_paths.
    for file_path in get_jsonl_file_paths():
        # All QA pairs of one .jsonl file.
        contents = get_QA_pairs(file_path)

        file_name = os.path.basename(file_path)
        print(file_name)
        storage_jsonl_path = os.path.join(clean_dir, file_name)

        logger.info(f'The generated QA will be stored in {storage_jsonl_path}.')

        chunks = np.array_split(np.array(contents), multi_process_num)

        # One argument list per worker.
        parameters_list = [
            [thread_num, interval, model_caller, storage_jsonl_path, list(chunk)]
            for thread_num, chunk in enumerate(chunks)
        ]

        with concurrent.futures.ThreadPoolExecutor(max_workers=multi_process_num) as executor:
            futures = [
                executor.submit(single_thread_generate, *parameters)
                for parameters in parameters_list
            ]

            for future in concurrent.futures.as_completed(futures):
                try:
                    future.result()
                except Exception as exc:
                    logger.error("Thread generated an exception: %s", exc)

        # NOTE(review): every worker above is given the same storage_jsonl_path,
        # while this merge is pointed at result_dir rather than clean_dir.
        # Confirm that merging from result_dir is intended for the cleaning flow.
        merge_sub_qa_generation(result_dir, storage_jsonl_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Create the output directory for cleaned QA pairs. Use the configured
    # clean_dir (the directory the results are actually written to) instead
    # of a path relative to the current working directory.
    os.makedirs(clean_dir, exist_ok=True)
    clean_qa(interval=storage_interval)
|
8
scripts/qa_generation/choose_prompt.md
Normal file
8
scripts/qa_generation/choose_prompt.md
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
你是一名经验丰富的心理咨询师,熟悉心理学相关知识。根据我提供的 QA 对,来判断这个 QA 对是否属于心理学范畴。
|
||||||
|
|
||||||
|
标准如下:
|
||||||
|
- 若当前 QA 对属于心理学范畴,则返回1
|
||||||
|
- 若当前 QA 对不属于心理学范畴,则返回0
|
||||||
|
|
||||||
|
|
||||||
|
以下是给定的心理学 QA 对内容:
|
@ -10,7 +10,9 @@ base_dir = os.path.dirname(cur_dir) # ba
|
|||||||
model_dir = os.path.join(base_dir, 'model') # model
|
model_dir = os.path.join(base_dir, 'model') # model
|
||||||
|
|
||||||
# data
|
# data
|
||||||
data_dir = os.path.join(base_dir, 'data') # data
|
data_dir = os.path.join(base_dir, 'data')
|
||||||
|
clean_dir = os.path.join(data_dir, 'cleaned')
|
||||||
|
judge_dir = os.path.join(data_dir, '数据整合')
|
||||||
result_dir = os.path.join(data_dir, 'generated') # result
|
result_dir = os.path.join(data_dir, 'generated') # result
|
||||||
|
|
||||||
# log
|
# log
|
||||||
@ -18,7 +20,9 @@ log_dir = os.path.join(base_dir, 'log') # lo
|
|||||||
log_file_path = os.path.join(log_dir, 'log.log') # file
|
log_file_path = os.path.join(log_dir, 'log.log') # file
|
||||||
|
|
||||||
# system prompt
|
# system prompt
|
||||||
|
# Prompt内容
|
||||||
system_prompt_file_path = os.path.join(base_dir, 'system_prompt_v2.md') # system prompt
|
system_prompt_file_path = os.path.join(base_dir, 'system_prompt_v2.md') # system prompt
|
||||||
|
wash_prompt_file_path = os.path.join(base_dir, 'choose_prompt.md')
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -28,7 +32,6 @@ system_prompt_file_path = os.path.join(base_dir, 'system_prompt_v2.md') # sy
|
|||||||
DASHSCOPE_API_KEY = ''
|
DASHSCOPE_API_KEY = ''
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
控制参数
|
控制参数
|
||||||
"""
|
"""
|
||||||
@ -36,3 +39,4 @@ storage_interval = 10
|
|||||||
window_size = 8
|
window_size = 8
|
||||||
overlap_size = 2
|
overlap_size = 2
|
||||||
multi_process_num = 3
|
multi_process_num = 3
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from dashscope.api_entities.dashscope_response import Role
|
|||||||
|
|
||||||
from config.config import DASHSCOPE_API_KEY
|
from config.config import DASHSCOPE_API_KEY
|
||||||
from util.logger import get_logger
|
from util.logger import get_logger
|
||||||
from util.prompt_loader import load_system_prompt
|
from util.prompt_loader import load_system_prompt, load_wash_prompt
|
||||||
|
|
||||||
|
|
||||||
dashscope.api_key = DASHSCOPE_API_KEY
|
dashscope.api_key = DASHSCOPE_API_KEY
|
||||||
@ -39,3 +39,31 @@ def call_qwen_single_turn(query: str) -> str:
|
|||||||
response.code, response.message
|
response.code, response.message
|
||||||
))
|
))
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def call_qwen_Psychology_QA_Pairs(query: str) -> str:
    """Ask the qwen model to judge a QA pair using the wash prompt.

    The text returned by ``load_wash_prompt`` is sent as the system message
    and ``query`` as the user message.

    Args:
        query: The QA pair text to be judged.

    Returns:
        The message content of the first choice when the response status is
        ``HTTPStatus.OK``; otherwise the failure is logged and an empty
        string is returned.
    """
    system_message = {
        'role': Role.SYSTEM,
        'content': load_wash_prompt()
    }
    user_message = {
        'role': Role.USER,
        'content': query
    }

    response = Generation.call(
        model='qwen-max-1201',
        messages=[system_message, user_message],
        result_format='message',
        stream=False,
        incremental_output=False
    )

    # Guard clause: report the failure and bail out with an empty answer.
    if response.status_code != HTTPStatus.OK:
        failure_details = (
            response.request_id, response.status_code,
            response.code, response.message
        )
        logger.error(
            'Request id: %s, Status code: %s, error code: %s, error message: %s' % failure_details
        )
        return ""

    return response.output.choices[0]['message']['content']
|
||||||
|
@ -4,11 +4,39 @@ import json
|
|||||||
import glob
|
import glob
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
|
||||||
from config.config import data_dir
|
from config.config import data_dir, judge_dir
|
||||||
from util.logger import get_logger
|
from util.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def get_jsonl_file_paths(root_dir=None) -> List[str]:
    """Recursively collect the paths of all ``.jsonl`` files under a directory.

    Args:
        root_dir: Directory to walk. When omitted (``None``) the configured
            ``judge_dir`` is used, which keeps existing zero-argument calls
            working unchanged.

    Returns:
        A list with the path of every file whose name ends in ``.jsonl``,
        built by joining each directory reported by ``os.walk`` with the
        file name.
    """
    if root_dir is None:
        root_dir = judge_dir

    json_file_paths = []
    # Walk the root directory and all of its sub-directories.
    for dirpath, _dirnames, filenames in os.walk(root_dir):
        for filename in filenames:
            # A plain suffix test is enough here; no regular expression needed.
            if filename.endswith('.jsonl'):
                json_file_paths.append(os.path.join(dirpath, filename))

    return json_file_paths
|
||||||
|
|
||||||
|
def get_QA_pairs(json_path):
    """Read a ``.jsonl`` file and return its QA pairs as raw strings.

    Args:
        json_path: Path of the file to read (decoded as UTF-8).

    Returns:
        A list with one stripped string per non-blank line of the file, in
        file order. An empty or whitespace-only file yields an empty list
        rather than a list containing one empty string.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        # Iterating the file splits on line breaks only, so unusual Unicode
        # separators inside a JSON string do not break a record apart.
        return [line.strip() for line in f if line.strip()]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
递归获取 data_dir 下的所有 .txt 文件列表
|
递归获取 data_dir 下的所有 .txt 文件列表
|
||||||
"""
|
"""
|
||||||
@ -47,7 +75,7 @@ def get_txt_content(
|
|||||||
res = []
|
res = []
|
||||||
sentences_amount = len(sentences)
|
sentences_amount = len(sentences)
|
||||||
start_index, end_index = 0, sentences_amount - window_size
|
start_index, end_index = 0, sentences_amount - window_size
|
||||||
## check length
|
# check length
|
||||||
if window_size < overlap_size:
|
if window_size < overlap_size:
|
||||||
logger.error("window_size must be greater than or equal to overlap_size")
|
logger.error("window_size must be greater than or equal to overlap_size")
|
||||||
return None
|
return None
|
||||||
@ -56,7 +84,7 @@ def get_txt_content(
|
|||||||
return ['\n'.join(sentences)]
|
return ['\n'.join(sentences)]
|
||||||
|
|
||||||
for i in range(start_index, end_index + 1, overlap_size):
|
for i in range(start_index, end_index + 1, overlap_size):
|
||||||
res.append('\n'.join(sentences[i : i + window_size]))
|
res.append('\n'.join(sentences[i: i + window_size]))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@ -80,6 +108,7 @@ def capture_qa(content: str) -> List[Dict]:
|
|||||||
logger.warning("No JSON block found.")
|
logger.warning("No JSON block found.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
将 storage_list 存入到 storage_jsonl_path
|
将 storage_list 存入到 storage_jsonl_path
|
||||||
"""
|
"""
|
||||||
@ -88,6 +117,7 @@ def save_to_file(storage_jsonl_path, storage_list):
|
|||||||
for item in storage_list:
|
for item in storage_list:
|
||||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
将并发产生的文件合并成为一个文件
|
将并发产生的文件合并成为一个文件
|
||||||
"""
|
"""
|
||||||
@ -103,4 +133,3 @@ def merge_sub_qa_generation(directory, storage_jsonl_path):
|
|||||||
file_contents.append(json.loads(line))
|
file_contents.append(json.loads(line))
|
||||||
os.remove(file_path)
|
os.remove(file_path)
|
||||||
save_to_file(storage_jsonl_path, file_contents)
|
save_to_file(storage_jsonl_path, file_contents)
|
||||||
|
|
||||||
|
@ -1,7 +1,14 @@
|
|||||||
from config.config import system_prompt_file_path
|
from config.config import system_prompt_file_path
|
||||||
|
from config.config import wash_prompt_file_path
|
||||||
|
|
||||||
|
|
||||||
def load_system_prompt() -> str:
    """Return the full text of the file at ``system_prompt_file_path``, read as UTF-8."""
    with open(system_prompt_file_path, 'r', encoding='utf-8') as prompt_file:
        return prompt_file.read()
|
||||||
|
|
||||||
|
|
||||||
|
def load_wash_prompt() -> str:
    """Return the full text of the file at ``wash_prompt_file_path``, read as UTF-8."""
    with open(wash_prompt_file_path, 'r', encoding='utf-8') as prompt_file:
        return prompt_file.read()
|
||||||
|
Loading…
Reference in New Issue
Block a user