Merge pull request #52 from Anooyman/main

Add concurrent functions
This commit is contained in:
xzw 2024-03-09 19:50:18 +08:00 committed by GitHub
commit 7891e1aa5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 6 deletions

View File

@ -35,3 +35,4 @@ DASHSCOPE_API_KEY = ''
storage_interval = 10 storage_interval = 10
window_size = 8 window_size = 8
overlap_size = 2 overlap_size = 2
multi_process_num = 3

View File

@ -1,9 +1,11 @@
import os import os
import json import json
import time
from tqdm import tqdm from tqdm import tqdm
import concurrent.futures
from datetime import datetime from datetime import datetime
from config.config import result_dir, storage_interval, window_size, overlap_size from config.config import result_dir, storage_interval, window_size, overlap_size, multi_process_num
from model.qwen import call_qwen_single_turn from model.qwen import call_qwen_single_turn
from util.logger import get_logger from util.logger import get_logger
from util.data_loader import get_file_list, get_txt_content, capture_qa from util.data_loader import get_file_list, get_txt_content, capture_qa
@ -42,15 +44,28 @@ def generate_qa(
storage_jsonl_path = os.path.join( storage_jsonl_path = os.path.join(
result_dir, f'{current_time}-{file_name}-{model_name}.jsonl') result_dir, f'{current_time}-{file_name}-{model_name}.jsonl')
logger.info(f'The generated QA will be stored in {storage_jsonl_path}.') logger.info(f'The generated QA will be stored in {storage_jsonl_path}.')
for content in tqdm(contents): response_list = []
response = model_caller(content) logger.info("Begin generate QA data")
# print(response) # 打印输出 with tqdm(total=len(contents)) as pbar:
with concurrent.futures.ThreadPoolExecutor(max_workers=multi_process_num) as executor:
# 创建一个Future列表它们将对应每个worker_function的结果
futures = {executor.submit(model_caller, content): content for content in contents}
for future in concurrent.futures.as_completed(futures):
try:
response_list.append(future.result())
pbar.update(1)
except Exception as exc:
logger.error("Item generated an exception: %s" % (exc))
logger.info("Begin capture LLM response")
for response in tqdm(response_list):
captured_qa = capture_qa(response) captured_qa = capture_qa(response)
# print(captured_qa) # 打印QA对 # print(captured_qa) # 打印QA对
if captured_qa is None: if captured_qa is None:
continue continue
storage_list.extend(captured_qa) storage_list.extend(captured_qa)
storage_counter += 1 storage_counter += 1
if storage_counter % interval == 0: if storage_counter % interval == 0: