# olivebot/test/FunAudioLLM/SenseVoice/server.py

import asyncio
import websockets
import argparse
import json
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
import os
# Module logger. Its threshold is pinned to CRITICAL, which silences the
# logger.info(...) / logger.error(...) calls made elsewhere in this file.
import logging

_LOG_LEVEL = logging.CRITICAL
logger = logging.getLogger(__name__)
logger.setLevel(_LOG_LEVEL)
# Parse command-line options.
#   --host  address the websocket server binds to
#   --port  websocket server port
#   --ngpu  0 runs the model on the CPU, 1 on a GPU
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0", help="host ip, localhost, 0.0.0.0")
parser.add_argument("--port", type=int, default=10197, help="websocket server port")
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
args = parser.parse_args()
# Load the SenseVoice ASR model (with the "fsmn-vad" VAD model) once at start-up.
print("model loading")
model_dir = "iic/SenseVoiceSmall"
# Honour --ngpu: 0 selects the CPU. Any positive value keeps the GPU this
# script has always used, which is hard-coded to the second CUDA device
# ("cuda:1").
asr_device = "cuda:1" if args.ngpu > 0 else "cpu"
asr_model = AutoModel(
    model=model_dir,
    trust_remote_code=True,
    remote_code="./model.py",
    vad_model="fsmn-vad",
    # presumably milliseconds, i.e. 30 s max per VAD segment -- confirm in FunASR docs
    vad_kwargs={"max_single_segment_time": 30000},
    device=asr_device,
)
print("model loaded")
# Live client connections keyed by id(websocket). ws_serve adds and removes
# entries; nothing in this file reads them otherwise.
websocket_users = dict()
# Unbounded FIFO of (websocket, url) jobs: ws_serve produces, worker consumes.
# NOTE(review): on Python < 3.10 a Queue created at import time binds to a
# different event loop than the one asyncio.run() starts -- confirm the target
# Python version.
task_queue = asyncio.Queue(maxsize=0)
async def ws_serve(websocket, path):
    """Serve one websocket client until it disconnects.

    Text frames are parsed as JSON. An object with a ``"url"`` key queues
    ``(websocket, url)`` for ``worker``. Binary frames and objects without
    ``"url"`` are ignored. A frame that is not valid JSON, or is not a JSON
    object, is skipped instead of tearing down the whole connection.

    Args:
        websocket: Client connection handed in by ``websockets.serve``.
        path: Request path from the two-argument handler signature of the
            legacy ``websockets`` API; not used here.
    """
    user_id = id(websocket)
    # Only mutated, never rebound, so no ``global`` declaration is needed.
    websocket_users[user_id] = websocket
    try:
        async for message in websocket:
            if not isinstance(message, str):
                continue  # binary frames carry no job
            try:
                data = json.loads(message)
            except json.JSONDecodeError as e:
                logger.warning(f"Ignoring malformed JSON message: {e}")
                continue
            # NOTE(review): ``url`` is client-controlled; process_wav_file
            # transcribes it and then deletes that path from disk. Consider
            # restricting it to a dedicated upload directory.
            if isinstance(data, dict) and 'url' in data:
                await task_queue.put((websocket, data['url']))
    except websockets.exceptions.ConnectionClosed as e:
        logger.info(f"Connection closed: {e.reason}")
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
    finally:
        logger.info(f"Cleaning up connection for user {user_id}")
        websocket_users.pop(user_id, None)
        await websocket.close()
        logger.info("WebSocket closed")
async def worker():
    """Process queued transcription jobs one at a time, forever.

    Each ``(websocket, url)`` job taken from ``task_queue`` is passed to
    ``process_wav_file`` if its client is still connected, and skipped
    otherwise. Any exception escaping a job is reported and the loop keeps
    going, so a single failing job cannot end this task (which ``main``
    awaits, so its death would stop the server).
    """
    while True:
        websocket, url = await task_queue.get()
        try:
            if websocket.open:
                await process_wav_file(websocket, url)
            else:
                logger.info("WebSocket connection is already closed when trying to process file")
        except Exception as e:
            # Printed rather than logged: the module logger is set to CRITICAL,
            # and process_wav_file reports its own errors with print as well.
            print(f"Error while processing {url}: {e}")
        finally:
            # Balance every get() with task_done(), even when the job failed.
            task_queue.task_done()
def _read_hotwords(path="data/hotword.txt"):
    """Return the hotword file's lines, each stripped, joined by single spaces.

    Returns None (after printing why) when the file cannot be opened or
    decoded, so a missing hotword list only disables hotwords instead of
    failing the request.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return " ".join(line.strip() for line in f)
    except (OSError, UnicodeDecodeError) as e:
        print(f"Could not read hotwords from {path}: {e}")
        return None


async def process_wav_file(websocket, url):
    """Transcribe the audio file at ``url`` and send the text to ``websocket``.

    Hotwords are re-read from ``data/hotword.txt`` on every call. The blocking
    ``asr_model.generate`` call runs in the event loop's default executor, so
    the loop keeps answering pings and serving other clients meanwhile. The
    text is sent only if the first result has a ``"text"`` field and the
    client is still connected. Errors from transcription or sending are
    printed, not raised, and the audio file is deleted afterwards either way.

    Args:
        websocket: Client connection that receives the transcription.
        url: Path of the audio file to transcribe, as sent by the client.
    """
    param_dict = {"sentence_timestamp": False}
    hotword = _read_hotwords()
    if hotword is not None:
        print(f"热词:{hotword}")
        param_dict["hotword"] = hotword
    wav_path = url
    try:
        loop = asyncio.get_running_loop()
        res = await loop.run_in_executor(
            None,
            lambda: asr_model.generate(
                input=wav_path,
                cache={},
                language="auto",  # or "zh", "en", "yue", "ja", "ko", "nospeech"
                use_itn=True,
                batch_size_s=60,
                merge_vad=True,
                merge_length_s=15,
                **param_dict,
            ),
        )
        if res and 'text' in res[0] and websocket.open:
            text = rich_transcription_postprocess(res[0]["text"])
            await websocket.send(text)
    except Exception as e:
        print(f"Error during model.generate: {e}")
    finally:
        # SECURITY NOTE(review): wav_path comes from the client unvalidated, so
        # any websocket client can make the server delete an arbitrary file it
        # has permission to remove. Restrict it to an upload directory.
        try:
            os.remove(wav_path)
        except FileNotFoundError:
            pass
async def main():
    """Start the websocket server, then run the transcription worker forever.

    ``task_queue`` is re-created here so it is bound to the loop started by
    ``asyncio.run``. On Python < 3.10 an ``asyncio.Queue`` built at import
    time is tied to a different event loop, and the worker's first ``get()``
    fails with "attached to a different loop". On 3.10+ this is harmless.
    """
    global task_queue
    task_queue = asyncio.Queue()
    # Awaiting serve() starts listening before the worker begins consuming.
    await websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=10)
    # The worker loops forever, so this keeps main() (and the server) running.
    worker_task = asyncio.create_task(worker())
    await worker_task
# Run the main coroutine with asyncio, only when executed as a script.
if __name__ == "__main__":
    asyncio.run(main())