OliveSensorAPI/scripts/qa_generation/main.py

76 lines
2.5 KiB
Python
Raw Normal View History

2024-03-07 17:56:07 +08:00
import os
import json
from tqdm import tqdm
from datetime import datetime
from config.config import result_dir
from model.qwen import call_qwen_single_turn
from util.logger import get_logger
from util.data_loader import get_file_list, get_txt_content, capture_qa
logger = get_logger()
"""
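Generate QA pairs.
model_name: name of the model to call; currently only qwen is implemented.
interval: storage interval, i.e. how many items are accumulated between file writes; an interval that is too small increases I/O overhead.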
"""
def _append_jsonl(path, items):
    """Append every item in *items* to *path*, one JSON object per line.

    The file is opened in append mode with UTF-8 encoding so batches written
    earlier are preserved, and non-ASCII text is stored verbatim
    (``ensure_ascii=False``).
    """
    with open(path, 'a', encoding='utf-8') as f:
        f.writelines(
            json.dumps(item, ensure_ascii=False) + '\n' for item in items
        )


def generate_qa(
    model_name: str = 'qwen',
    interval: int = 10,
):
    """Generate QA pairs for every input file and store them as JSONL.

    For each path returned by ``get_file_list()``, every content chunk from
    ``get_txt_content()`` is sent to the model, the response is passed to
    ``capture_qa()``, and the captured items are appended to a per-file
    ``.jsonl`` file inside ``result_dir``.

    Args:
        model_name: Name of the model to call. Only ``'qwen'`` is implemented;
            any other value logs a warning and falls back to qwen.
        interval: Storage interval, i.e. how many captured responses are
            buffered before they are written to disk. A very small interval
            increases I/O overhead.

    Raises:
        ValueError: If ``interval`` is smaller than 1.
    """
    # Fail fast: an interval of 0 would otherwise blow up with a
    # ZeroDivisionError only after model calls have already been made.
    if interval < 1:
        raise ValueError(f'interval must be >= 1, got {interval!r}')

    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    if model_name != 'qwen':
        logger.warning('This model is currently not supported and will call the default model - qwen.')
        model_name = 'qwen'
    model_caller = call_qwen_single_turn

    logger.info(f'The called model is: {model_name}.')
    logger.info(f'The storage interval is: {interval}.')

    # Make sure the output directory exists so the function also works when
    # it is imported and called without the script entry point.
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)

    for file_path in get_file_list():
        contents = get_txt_content(file_path)

        _, file_name = os.path.split(file_path)
        storage_jsonl_path = os.path.join(
            result_dir, f'{current_time}-{file_name}-{model_name}.jsonl')
        logger.info(f'The generated QA will be stored in {storage_jsonl_path}.')

        # Buffer and counter are per file, so the interval never carries
        # over from the previous file.
        pending_items = []
        pending_responses = 0

        for content in tqdm(contents):
            response = model_caller(content)
            captured_qa = capture_qa(response)
            if captured_qa is None:
                # Nothing usable could be extracted from this response.
                continue

            pending_items.extend(captured_qa)
            pending_responses += 1
            if pending_responses >= interval:
                _append_jsonl(storage_jsonl_path, pending_items)
                pending_items = []
                pending_responses = 0

        # Store whatever is left over for this file.
        if pending_items:
            _append_jsonl(storage_jsonl_path, pending_items)
if __name__ == '__main__':
    # Create the "generated" data folder up front (no error if it already
    # exists), then run QA generation with the default model and interval.
    os.makedirs('./data/generated', exist_ok=True)
    generate_qa()