Update multi-threaded QA generation process
Each thread stores its generated QA pairs in its own sub-file; once all threads finish, the sub-files are merged into a single output file.
This commit is contained in:
parent
e9c7ebf8df
commit
c01c67c33f
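In outline, the change fans work out to worker threads that each append to a private sub-file, then stitches the sub-files together at the end. A minimal standalone sketch of that shape, with a hypothetical worker, paths, and payloads standing in for the real QA generation:

from concurrent.futures import ThreadPoolExecutor
import glob
import json
import os

def worker(thread_num, items, base_path):
    # Each worker appends its results to its own sub-file, so no two
    # threads ever write to the same file handle.
    with open(f"{base_path}-{thread_num}", "a", encoding="utf-8") as f:
        for item in items:
            f.write(json.dumps({"qa": item}, ensure_ascii=False) + "\n")

base_path = "out.jsonl"                      # hypothetical output path
chunks = [["q1", "q2"], ["q3"]]              # pre-split work for two threads
with ThreadPoolExecutor(max_workers=2) as executor:
    for n, chunk in enumerate(chunks):
        executor.submit(worker, n, chunk, base_path)

# Merge the per-thread sub-files into the final file, then remove them.
with open(base_path, "a", encoding="utf-8") as out:
    for path in glob.glob(base_path + "-*"):
        with open(path, "r", encoding="utf-8") as f:
            out.write(f.read())
        os.remove(path)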
@@ -4,14 +4,46 @@ import time

 from tqdm import tqdm
 import concurrent.futures
 from datetime import datetime
+import numpy as np

 from config.config import result_dir, storage_interval, window_size, overlap_size, multi_process_num
 from model.qwen import call_qwen_single_turn
 from util.logger import get_logger
-from util.data_loader import get_file_list, get_txt_content, capture_qa
+from util.data_loader import get_file_list, get_txt_content, capture_qa, merge_sub_qa_generation, save_to_file

 logger = get_logger()


+"""
+Each thread generates QA pairs and stores them in its own sub-file
+"""
+def single_thread_generate(thread_num, interval, model_caller, storage_jsonl_path, contents):
+
+    storage_counter = 0
+    storage_list = []
+
+    for content in tqdm(contents):
+        try:
+            response = model_caller(content)
+            captured_qa = capture_qa(response)
+
+            if captured_qa is None:
+                continue
+
+            storage_list.extend(captured_qa)
+            storage_counter += 1
+
+            if storage_counter % interval == 0:
+                save_to_file(storage_jsonl_path, storage_list)
+                storage_counter = 0
+                storage_list = []
+
+        except Exception as exc:
+            logger.error("QA generation error : %s" % (exc))
+
+    if storage_list:
+        save_to_file(storage_jsonl_path, storage_list)
+        storage_list = []
+
+
 """
 Generate QA pairs
 model_name: name of the callable model; only qwen is implemented so far
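The buffer-and-flush pattern inside single_thread_generate bounds how much work a crashing thread can lose: at most `interval` successfully parsed responses sit in memory before they hit disk. A standalone sketch of the pattern, with hypothetical names (flush stands in for save_to_file, and the input list stands in for parsed model responses):

import json

def flush(path, buffer):
    # Append the buffered QA items to the file, one JSON object per line.
    with open(path, "a", encoding="utf-8") as f:
        for item in buffer:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

def generate_with_flush(path, parsed_responses, interval=3):
    buffer, counter = [], 0
    for qa_list in parsed_responses:
        if qa_list is None:              # unparseable responses are skipped
            continue
        buffer.extend(qa_list)
        counter += 1
        if counter % interval == 0:      # flush after every `interval` successes
            flush(path, buffer)
            buffer = []
    if buffer:                           # write whatever is left at the end
        flush(path, buffer)

generate_with_flush("worker-0.jsonl", [[{"q": 1}], None, [{"q": 2}], [{"q": 3}]])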
@@ -45,43 +77,29 @@ def generate_qa(
         result_dir, f'{current_time}-{file_name}-{model_name}.jsonl')
     logger.info(f'The generated QA will be stored in {storage_jsonl_path}.')

-    response_list = []
-    logger.info("Begin generate QA data")
-    with tqdm(total=len(contents)) as pbar:
-        with concurrent.futures.ThreadPoolExecutor(max_workers=multi_process_num) as executor:
-            # Create a list of Futures, one for each worker_function result
-            futures = {executor.submit(model_caller, content): content for content in contents}
-
-            for future in concurrent.futures.as_completed(futures):
-                try:
-                    response_list.append(future.result())
-                    pbar.update(1)
-                except Exception as exc:
-                    logger.error("Item generated an exception: %s" % (exc))
-
-    logger.info("Begin capture LLM response")
-    for response in tqdm(response_list):
-        captured_qa = capture_qa(response)
-        # print(captured_qa)  # print the QA pairs
-        if captured_qa is None:
-            continue
-
-        storage_list.extend(captured_qa)
-        storage_counter += 1
-        if storage_counter % interval == 0:
-            storage_counter = 0
-            with open(storage_jsonl_path, 'a', encoding='utf-8') as f:
-                for item in storage_list:
-                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
-            storage_list = []
-
-    # Store any remainder
-    if storage_list:
-        with open(storage_jsonl_path, 'a', encoding='utf-8') as f:
-            for item in storage_list:
-                f.write(json.dumps(item, ensure_ascii=False) + '\n')
-        storage_list = []
+    # Split contents into chunks based on the concurrency level
+    contents_array = np.array(contents)
+    chunks = np.array_split(contents_array, multi_process_num)
+
+    # Build the list of per-thread parameters
+    parameters_list = list()
+    for thread_num, chunk in enumerate(chunks):
+        parameters_list.append(
+            [thread_num, interval, model_caller, storage_jsonl_path + f'-{thread_num}', list(chunk)]
+        )
+
+    # Generate QA pairs concurrently
+    with concurrent.futures.ThreadPoolExecutor(max_workers=multi_process_num) as executor:
+        # Create a list of Futures, one for each worker invocation
+        futures = [executor.submit(single_thread_generate, *parameters) for parameters in parameters_list]
+
+        for future in concurrent.futures.as_completed(futures):
+            try:
+                future.result()
+            except Exception as exc:
+                logger.error("Thread generated an exception: %s" % (exc))
+
+    merge_sub_qa_generation(result_dir, storage_jsonl_path)


 if __name__ == '__main__':
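The chunking relies on np.array_split, which, unlike np.split, tolerates a length that is not evenly divisible by multi_process_num; chunk sizes differ by at most one, so every thread gets a near-equal share. A small demonstration:

import numpy as np

contents = [f"doc-{i}" for i in range(10)]
chunks = np.array_split(np.array(contents), 4)   # e.g. multi_process_num = 4

print([len(chunk) for chunk in chunks])          # [3, 3, 2, 2]
print([str(s) for s in chunks[0]])               # ['doc-0', 'doc-1', 'doc-2']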
@@ -1,6 +1,7 @@
 import os
 import re
 import json
+import glob
 from typing import List, Dict

 from config.config import data_dir
@@ -78,3 +79,28 @@ def capture_qa(content: str) -> List[Dict]:
     else:
         logger.warning("No JSON block found.")
         return None
+
+
+"""
+Write storage_list out to storage_jsonl_path
+"""
+def save_to_file(storage_jsonl_path, storage_list):
+    with open(storage_jsonl_path, 'a', encoding='utf-8') as f:
+        for item in storage_list:
+            f.write(json.dumps(item, ensure_ascii=False) + '\n')
+
+
+"""
+Merge the files produced by the concurrent workers into a single file
+"""
+def merge_sub_qa_generation(directory, storage_jsonl_path):
+
+    # Find all files that start with the given prefix
+    matching_files = glob.glob(os.path.join(directory, storage_jsonl_path + "*"))
+
+    file_contents = []
+    for file_path in matching_files:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            for line in f:
+                file_contents.append(json.loads(line))
+        os.remove(file_path)
+    save_to_file(storage_jsonl_path, file_contents)
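save_to_file appends one JSON object per line, and ensure_ascii=False keeps the Chinese QA text human-readable on disk; merge_sub_qa_generation reads the same format back line by line. A small round-trip sketch with a hypothetical path:

import json

path = "demo.jsonl"                      # hypothetical path
items = [{"question": "你好吗?", "answer": "我很好。"}]

# save_to_file-style append: one JSON object per line, with
# ensure_ascii=False keeping the Chinese text readable on disk.
with open(path, "a", encoding="utf-8") as f:
    for item in items:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")

# Read the sub-file back line by line, as merge_sub_qa_generation does.
with open(path, "r", encoding="utf-8") as f:
    loaded = [json.loads(line) for line in f]

print(loaded[0]["question"])             # 你好吗?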