 c01c67c33f
			
		
	
	
		c01c67c33f
		
	
	
	
	
		
			
			Each thread stores the generated content independently, and finally integrates the generated content into a file
		
			
				
	
	
		
			107 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			107 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import re
 | |
| import json
 | |
| import glob
 | |
| from typing import List, Dict
 | |
| 
 | |
| from config.config import data_dir
 | |
| from util.logger import get_logger
 | |
| 
 | |
# Module-level logger used by the helper functions below for warnings and errors.
logger = get_logger()
 | |
| 
 | |
def get_file_list() -> List[str]:
    """Recursively collect the paths of all ``.txt`` files under ``data_dir``.

    Returns:
        The joined directory/file-name path of every file whose name ends
        with ``.txt``, in ``os.walk`` order. The list is empty (and a warning
        is logged) when no such file exists.
    """
    found: List[str] = []
    for parent, _subdirs, names in os.walk(data_dir):
        found.extend(
            os.path.join(parent, name) for name in names if name.endswith('.txt')
        )

    # Nothing matched: warn so the data directory can be checked.
    if len(found) == 0:
        logger.warning(f'No txt text found in {data_dir}, please check!')
    return found
 | |
| 
 | |
def get_txt_content(
    file_path: str,
    window_size: int = 6,
    overlap_size: int = 2
) -> List[str]:
    """Read a UTF-8 text file and return its sentences grouped into sliding windows.

    The text is split after a sentence terminator (the ideographic full stop,
    or an exclamation/question mark in either ASCII or full-width form) that
    is followed by whitespace. Spaces and tabs inside each sentence are then
    removed.

    Args:
        file_path: Path of the text file to read.
        window_size: Number of sentences per window.
        overlap_size: Number of sentences the window start advances between
            consecutive windows (it is used as the ``range`` step), so
            neighbouring windows share ``window_size - overlap_size``
            sentences.

    Returns:
        A list of windows, each being the window's sentences joined by
        newlines. When ``window_size`` is at least the number of sentences,
        a single-element list with the whole text is returned. ``None`` is
        returned (after logging an error) when ``window_size`` is smaller
        than ``overlap_size`` or when ``overlap_size`` is not positive.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read().strip()

    # Simple splitter: break after a sentence terminator followed by
    # whitespace. Both ASCII and full-width "!" / "?" are accepted.
    sentences = re.split(r'(?<=[。!?!?])\s+', content)
    sentences = [s.replace(' ', '').replace('\t', '') for s in sentences]
    sentences_amount = len(sentences)

    if window_size < overlap_size:
        logger.error("window_size must be greater than or equal to overlap_size")
        return None
    if window_size >= sentences_amount:
        logger.warning("window_size exceeds the amount of sentences, and the complete text content will be returned")
        return ['\n'.join(sentences)]
    if overlap_size < 1:
        # A zero step would make range() raise and a negative one would
        # silently yield no windows at all.
        logger.error("overlap_size must be a positive integer")
        return None

    last_start = sentences_amount - window_size
    starts = list(range(0, last_start + 1, overlap_size))
    # When the step does not land exactly on the last possible start, add a
    # final window anchored at the end so trailing sentences are not dropped.
    if starts[-1] != last_start:
        starts.append(last_start)

    return ['\n'.join(sentences[i : i + window_size]) for i in starts]
 | |
| 
 | |
| 
 | |
def capture_qa(content: str) -> List[Dict]:
    """Extract the QA pairs from the first fenced ``json`` block in ``content``.

    Args:
        content: Text expected to contain a block delimited by a
            triple-backtick ``json`` fence and a closing triple-backtick.

    Returns:
        The value produced by ``json.loads`` on the text inside the first
        such block (its structure is not validated here), or ``None`` when
        no block is found or the block is not valid JSON. A warning is
        logged in both failure cases.
    """
    # Only the first json block is captured; later blocks are ignored.
    match = re.search(r'```json(.*?)```', content, re.DOTALL)
    if match is None:
        logger.warning("No JSON block found.")
        return None

    try:
        return json.loads(match.group(1))
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError subclass. Catching only parse
        # failures keeps KeyboardInterrupt/SystemExit from being swallowed.
        logger.warning('Unable to parse JSON properly.')
        return None
 | |
| 
 | |
def save_to_file(storage_jsonl_path, storage_list):
    """Append the items of ``storage_list`` to ``storage_jsonl_path`` as JSON Lines.

    Args:
        storage_jsonl_path: Path of the file to append to; it is opened in
            append mode with UTF-8 encoding.
        storage_list: Iterable of JSON-serialisable items. Each one is
            written on its own line, with non-ASCII characters kept as-is.
    """
    with open(storage_jsonl_path, mode='a', encoding='utf-8') as sink:
        sink.writelines(
            f"{json.dumps(record, ensure_ascii=False)}\n" for record in storage_list
        )
 | |
| 
 | |
def merge_sub_qa_generation(directory, storage_jsonl_path):
    """Merge the concurrently produced JSONL files into a single file.

    Every file matching ``os.path.join(directory, storage_jsonl_path + "*")``
    is read line by line as JSON. The matched files are then deleted and all
    collected records are appended to ``storage_jsonl_path`` through
    ``save_to_file``. If the output file itself matches the pattern, its
    existing records are read along with the others and written back by the
    final append.

    Args:
        directory: Directory joined in front of the file-name prefix.
        storage_jsonl_path: Prefix of the files to merge, and also the path
            of the merged output file.
    """
    # Find every file starting with the given prefix. The result is sorted
    # because glob returns matches in arbitrary order.
    pattern = os.path.join(directory, storage_jsonl_path + "*")
    matching_files = sorted(glob.glob(pattern))

    file_contents = []
    for file_path in matching_files:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Blank lines (e.g. a stray trailing newline) are not JSON; skip them.
            file_contents.extend(json.loads(line) for line in f if line.strip())

    # Delete only after every file has been parsed and closed. A parse error
    # then leaves all partial files on disk, and no file is removed while
    # still open (which fails on Windows).
    for file_path in matching_files:
        os.remove(file_path)

    save_to_file(storage_jsonl_path, file_contents)
 | |
| 
 |