Add concurrent functions (#1)
Add concurrent functions for QA generation Co-authored-by: edward_ke <edward_ke@trendmicro.com>
This commit is contained in:
		
							parent
							
								
									ff6751a639
								
							
						
					
					
						commit
						d60f1dc8e1
					
				| @ -35,3 +35,4 @@ DASHSCOPE_API_KEY = '' | |||||||
| storage_interval = 10 | storage_interval = 10 | ||||||
| window_size = 8 | window_size = 8 | ||||||
| overlap_size = 2 | overlap_size = 2 | ||||||
|  | multi_process_num = 3 | ||||||
|  | |||||||
| @ -1,9 +1,11 @@ | |||||||
| import os | import os | ||||||
| import json | import json | ||||||
|  | import time | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
|  | import concurrent.futures | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
| 
 | 
 | ||||||
| from config.config import result_dir, storage_interval, window_size, overlap_size | from config.config import result_dir, storage_interval, window_size, overlap_size, multi_process_num | ||||||
| from model.qwen import call_qwen_single_turn | from model.qwen import call_qwen_single_turn | ||||||
| from util.logger import get_logger | from util.logger import get_logger | ||||||
| from util.data_loader import get_file_list, get_txt_content, capture_qa | from util.data_loader import get_file_list, get_txt_content, capture_qa | ||||||
| @ -42,15 +44,28 @@ def generate_qa( | |||||||
|         storage_jsonl_path = os.path.join( |         storage_jsonl_path = os.path.join( | ||||||
|             result_dir, f'{current_time}-{file_name}-{model_name}.jsonl') |             result_dir, f'{current_time}-{file_name}-{model_name}.jsonl') | ||||||
|         logger.info(f'The generated QA will be stored in {storage_jsonl_path}.') |         logger.info(f'The generated QA will be stored in {storage_jsonl_path}.') | ||||||
|          | 
 | ||||||
|         for content in tqdm(contents): |         response_list = [] | ||||||
|             response = model_caller(content) |         logger.info("Begin generate QA data") | ||||||
|             # print(response) # 打印输出 |         with tqdm(total=len(contents)) as pbar: | ||||||
|  |             with concurrent.futures.ThreadPoolExecutor(max_workers=multi_process_num) as executor: | ||||||
|  |                 # 创建一个Future列表,它们将对应每个worker_function的结果 | ||||||
|  |                 futures = {executor.submit(model_caller, content): content for content in contents} | ||||||
|  |              | ||||||
|  |                 for future in concurrent.futures.as_completed(futures): | ||||||
|  |                     try: | ||||||
|  |                         response_list.append(future.result()) | ||||||
|  |                         pbar.update(1) | ||||||
|  |                     except Exception as exc: | ||||||
|  |                         logger.error("Item generated an exception: %s" % (exc)) | ||||||
|  | 
 | ||||||
|  |         logger.info("Begin capture LLM response") | ||||||
|  |         for response in tqdm(response_list): | ||||||
|             captured_qa = capture_qa(response) |             captured_qa = capture_qa(response) | ||||||
|             # print(captured_qa) # 打印QA对 |             # print(captured_qa) # 打印QA对 | ||||||
|             if captured_qa is None: |             if captured_qa is None: | ||||||
|                 continue |                 continue | ||||||
|              | 
 | ||||||
|             storage_list.extend(captured_qa) |             storage_list.extend(captured_qa) | ||||||
|             storage_counter += 1 |             storage_counter += 1 | ||||||
|             if storage_counter % interval == 0: |             if storage_counter % interval == 0: | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 Anooyman
						Anooyman