diff --git a/scripts/qa_generation/main.py b/scripts/qa_generation/main.py
index 342b6a9..d84187f 100644
--- a/scripts/qa_generation/main.py
+++ b/scripts/qa_generation/main.py
@@ -4,14 +4,46 @@ import time
 from tqdm import tqdm
 import concurrent.futures
 from datetime import datetime
+import numpy as np
 
 from config.config import result_dir, storage_interval, window_size, overlap_size, multi_process_num
 from model.qwen import call_qwen_single_turn
 from util.logger import get_logger
-from util.data_loader import get_file_list, get_txt_content, capture_qa
+from util.data_loader import get_file_list, get_txt_content, capture_qa, merge_sub_qa_generation, save_to_file
 
 logger = get_logger()
 
+"""
+Each worker thread generates QA pairs and appends them to its own sub-file.
+"""
+def single_thread_generate(thread_num, interval, model_caller, storage_jsonl_path, contents):
+
+    storage_counter = 0
+    storage_list = []
+
+    for content in tqdm(contents):
+        try:
+            response = model_caller(content)
+            captured_qa = capture_qa(response)
+
+            if captured_qa is None:
+                continue
+
+            storage_list.extend(captured_qa)
+            storage_counter += 1
+
+            # Flush to disk every `interval` successful generations.
+            if storage_counter % interval == 0:
+                save_to_file(storage_jsonl_path, storage_list)
+                storage_counter = 0
+                storage_list = []
+        except Exception as exc:
+            logger.error("QA generation error in thread %d: %s", thread_num, exc)
+
+    # Flush whatever is left after the loop.
+    if storage_list:
+        save_to_file(storage_jsonl_path, storage_list)
+        storage_list = []
+
+
 """
 Generate QA pairs.
 model_name: name of the callable model; only qwen is implemented for now.
@@ -45,43 +77,29 @@ def generate_qa(
         result_dir, f'{current_time}-{file_name}-{model_name}.jsonl')
     logger.info(f'The generated QA will be stored in {storage_jsonl_path}.')
 
-    response_list = []
-    logger.info("Begin generate QA data")
-    with tqdm(total=len(contents)) as pbar:
-        with concurrent.futures.ThreadPoolExecutor(max_workers=multi_process_num) as executor:
-            # Create a list of Futures, one per worker_function result.
-            futures = {executor.submit(model_caller, content): content for content in contents}
-
-            for future in concurrent.futures.as_completed(futures):
-                try:
-                    response_list.append(future.result())
-                    pbar.update(1)
-                except Exception as exc:
-                    logger.error("Item generated an exception: %s" % (exc))
+    # Split contents into one chunk per concurrent worker.
+    contents_array = np.array(contents)
+    chunks = np.array_split(contents_array, multi_process_num)
 
-    logger.info("Begin capture LLM response")
-    for response in tqdm(response_list):
-        captured_qa = capture_qa(response)
-        # print(captured_qa)  # print the QA pairs
-        if captured_qa is None:
-            continue
+    # Build the argument list for the workers.
+    parameters_list = list()
+    for thread_num, chunk in enumerate(chunks):
+        parameters_list.append(
+            [thread_num, interval, model_caller, storage_jsonl_path + f'-{thread_num}', list(chunk)]
+        )
 
-        storage_list.extend(captured_qa)
-        storage_counter += 1
-        if storage_counter % interval == 0:
-            storage_counter = 0
-            with open(storage_jsonl_path, 'a', encoding='utf-8') as f:
-                for item in storage_list:
-                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
-            storage_list = []
-
-    # Flush any remainder.
-    if storage_list:
-        with open(storage_jsonl_path, 'a', encoding='utf-8') as f:
-            for item in storage_list:
-                f.write(json.dumps(item, ensure_ascii=False) + '\n')
-        storage_list = []
+    # Generate QA pairs concurrently.
+    with concurrent.futures.ThreadPoolExecutor(max_workers=multi_process_num) as executor:
+        # One Future per worker, each handling one chunk of contents.
+        futures = [executor.submit(single_thread_generate, *parameters) for parameters in parameters_list]
+
+        for future in concurrent.futures.as_completed(futures):
+            try:
+                future.result()
+            except Exception as exc:
+                logger.error("Thread generated an exception: %s", exc)
+
+    merge_sub_qa_generation(result_dir, storage_jsonl_path)
 
 
 if __name__ == '__main__':
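The main.py change above swaps per-item futures for one long-lived worker per chunk, with each worker appending only to its own `-{thread_num}` sub-file so threads never write to the same file. Below is a minimal, self-contained sketch of that pattern; `fake_caller` is a hypothetical stand-in for `call_qwen_single_turn`, and `chunked` is a plain-Python substitute for `np.array_split`:

```python
import json
import concurrent.futures
from pathlib import Path
from tempfile import mkdtemp

def chunked(items, n):
    """Split items into n near-equal chunks (stand-in for np.array_split)."""
    k, r = divmod(len(items), n)
    chunks, start = [], 0
    for i in range(n):
        end = start + k + (1 if i < r else 0)
        chunks.append(items[start:end])
        start = end
    return chunks

def fake_caller(content):
    # Hypothetical stand-in for the real model call.
    return [{"question": f"Q about {content}", "answer": "A"}]

def worker(thread_num, out_path, contents):
    # Each worker appends only to its own sub-file, so no lock is needed.
    with open(out_path, "a", encoding="utf-8") as f:
        for content in contents:
            for qa in fake_caller(content):
                f.write(json.dumps(qa, ensure_ascii=False) + "\n")

if __name__ == "__main__":
    out_dir = Path(mkdtemp())
    base = out_dir / "demo.jsonl"
    contents = [f"doc-{i}" for i in range(10)]
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as ex:
        futures = [ex.submit(worker, i, f"{base}-{i}", chunk)
                   for i, chunk in enumerate(chunked(contents, 4))]
        for fut in concurrent.futures.as_completed(futures):
            fut.result()  # surface any worker exception
    print(sorted(p.name for p in out_dir.iterdir()))  # demo.jsonl-0 ... -3
```

Per-thread files sidestep interleaved writes without locking; the merge step in data_loader.py (next diff) stitches the sub-files back into one JSONL.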
diff --git a/scripts/qa_generation/util/data_loader.py b/scripts/qa_generation/util/data_loader.py
index 0724fca..fdfbfa9 100644
--- a/scripts/qa_generation/util/data_loader.py
+++ b/scripts/qa_generation/util/data_loader.py
@@ -1,6 +1,7 @@
 import os
 import re
 import json
+import glob
 from typing import List, Dict
 
 from config.config import data_dir
@@ -78,3 +79,28 @@ def capture_qa(content: str) -> List[Dict]:
     else:
         logger.warning("No JSON block found.")
         return None
+
+"""
+Append storage_list to the JSONL file at storage_jsonl_path.
+"""
+def save_to_file(storage_jsonl_path, storage_list):
+    with open(storage_jsonl_path, 'a', encoding='utf-8') as f:
+        for item in storage_list:
+            f.write(json.dumps(item, ensure_ascii=False) + '\n')
+
+"""
+Merge the sub-files produced by the concurrent workers into a single file.
+"""
+def merge_sub_qa_generation(directory, storage_jsonl_path):
+
+    # Find all sub-files sharing the target file's prefix. Glob on the basename:
+    # storage_jsonl_path already includes the directory, so joining the full
+    # path onto `directory` again would duplicate the prefix.
+    matching_files = glob.glob(os.path.join(directory, os.path.basename(storage_jsonl_path) + "-*"))
+
+    file_contents = []
+    for file_path in matching_files:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            for line in f:
+                file_contents.append(json.loads(line))
+        os.remove(file_path)
+    save_to_file(storage_jsonl_path, file_contents)
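The glob pattern in `merge_sub_qa_generation` is built from the basename plus the `-{thread_num}` suffix because `storage_jsonl_path` is already joined onto `result_dir` in main.py; re-joining the full path onto `directory` can duplicate the prefix (for a relative `result_dir`), and a bare `*` would also match the merged output itself. A small self-contained check of the merge behaviour, using a temporary directory and hypothetical file names:

```python
import os
import glob
import json
from tempfile import mkdtemp

def merge_sub_files(directory, target_path):
    # Mirrors merge_sub_qa_generation: glob on the basename so an already
    # directory-qualified target path is not re-joined onto `directory`.
    pattern = os.path.join(directory, os.path.basename(target_path) + "-*")
    records = []
    for sub_path in sorted(glob.glob(pattern)):
        with open(sub_path, "r", encoding="utf-8") as f:
            records.extend(json.loads(line) for line in f)
        os.remove(sub_path)  # sub-file is consumed once merged
    with open(target_path, "a", encoding="utf-8") as f:
        for item in records:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

if __name__ == "__main__":
    d = mkdtemp()
    target = os.path.join(d, "run.jsonl")
    for i in range(3):  # simulate three worker sub-files
        with open(f"{target}-{i}", "w", encoding="utf-8") as f:
            f.write(json.dumps({"thread": i}) + "\n")
    merge_sub_files(d, target)
    print(open(target, encoding="utf-8").read())  # three merged records
```

One caveat of the merge-then-delete design: if the process dies mid-merge, already-removed sub-files are lost, so deferring `os.remove` until after the final write (as the sketch does not) would be a slightly safer variant.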