Merge pull request #52 from Anooyman/main

Add concurrent functions
This commit is contained in:
xzw 2024-03-09 19:50:18 +08:00 committed by GitHub
commit 7891e1aa5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 6 deletions

View File

@ -35,3 +35,4 @@ DASHSCOPE_API_KEY = ''
storage_interval = 10
window_size = 8
overlap_size = 2
multi_process_num = 3

View File

@ -1,9 +1,11 @@
import os
import json
import time
from tqdm import tqdm
import concurrent.futures
from datetime import datetime
from config.config import result_dir, storage_interval, window_size, overlap_size
from config.config import result_dir, storage_interval, window_size, overlap_size, multi_process_num
from model.qwen import call_qwen_single_turn
from util.logger import get_logger
from util.data_loader import get_file_list, get_txt_content, capture_qa
@ -42,15 +44,28 @@ def generate_qa(
storage_jsonl_path = os.path.join(
result_dir, f'{current_time}-{file_name}-{model_name}.jsonl')
logger.info(f'The generated QA will be stored in {storage_jsonl_path}.')
for content in tqdm(contents):
response = model_caller(content)
# print(response) # 打印输出
response_list = []
logger.info("Begin generate QA data")
with tqdm(total=len(contents)) as pbar:
with concurrent.futures.ThreadPoolExecutor(max_workers=multi_process_num) as executor:
# Submit one model_caller task per content item; the dict maps each Future back to the content it was created from
futures = {executor.submit(model_caller, content): content for content in contents}
for future in concurrent.futures.as_completed(futures):
try:
response_list.append(future.result())
pbar.update(1)
except Exception as exc:
logger.error("Item generated an exception: %s" % (exc))
logger.info("Begin capture LLM response")
for response in tqdm(response_list):
captured_qa = capture_qa(response)
# print(captured_qa) # 打印QA对
if captured_qa is None:
continue
storage_list.extend(captured_qa)
storage_counter += 1
if storage_counter % interval == 0: