2024-11-06 18:42:52 +08:00
|
|
|
import asyncio
|
|
|
|
import websockets
|
|
|
|
import socket
|
|
|
|
import threading
|
|
|
|
import time
|
2024-11-06 18:45:44 +08:00
|
|
|
import sys
|
2024-11-06 18:42:52 +08:00
|
|
|
|
|
|
|
# Process-wide SocketBridgeService instance; created lazily by new_instance().
__wss = None


def new_instance():
    """Return the shared SocketBridgeService, constructing it on first use.

    Every later call returns that same object.
    """
    global __wss
    if __wss is not None:
        return __wss
    __wss = SocketBridgeService()
    return __wss
|
|
|
|
|
|
|
|
class SocketBridgeService:
    """Bridge WebSocket clients to a backend TCP service, one TCP socket per client.

    For every accepted WebSocket connection a dedicated TCP connection to
    ``tcp_host:tcp_port`` is opened. Messages received on the WebSocket are
    written to that TCP socket; bytes read from the TCP socket are put on
    ``message_queue`` and sent back to the WebSocket by a single consumer task.

    Blocking socket calls (connect / recv / sendall) run in worker threads via
    ``asyncio.to_thread`` so they do not stall the event loop.
    """

    def __init__(self, tcp_host='127.0.0.1', tcp_port=10001):
        # Backend TCP endpoint that every WebSocket connection is bridged to.
        self.tcp_host = tcp_host
        self.tcp_port = tcp_port
        # ws_id -> WebSocket connection object.
        self.websockets = {}
        # ws_id -> connected TCP socket for that WebSocket.
        self.sockets = {}
        # Queue of (ws_id, data) read from TCP sockets, awaiting delivery.
        # Created in start(), on the service's own event loop: on Python < 3.10
        # an asyncio.Queue binds to the loop that is current when it is
        # constructed, which here would be the caller's (different) thread.
        self.message_queue = None
        self.running = True
        self.loop = None
        # Strong references to background tasks so they are not GC'd mid-run.
        self.tasks = set()
        self.server = None
        # Monotonic connection-id source. id(websocket) values can be reused
        # once an object is freed, letting a stale task close a newer
        # connection's socket.
        self._next_ws_id = 0

    async def handler(self, websocket, path=None):
        """Serve one WebSocket connection for its whole lifetime.

        Opens the backend TCP socket, starts a task pumping TCP -> WebSocket,
        then forwards every incoming WebSocket message to the TCP socket until
        the client disconnects. ``path`` defaults to ``None`` so the handler
        works whether or not the websockets library passes a request path.
        """
        self._next_ws_id += 1
        ws_id = self._next_ws_id
        self.websockets[ws_id] = websocket
        try:
            sock = await self.create_socket_client()
            if sock is None:
                print(f"Failed to connect TCP socket for WebSocket {ws_id}")
                await websocket.close()
                return
            self.sockets[ws_id] = sock

            receive_task = asyncio.create_task(self.receive_from_socket(ws_id))
            self.tasks.add(receive_task)
            receive_task.add_done_callback(self.tasks.discard)

            async for message in websocket:
                await self.send_to_socket(ws_id, message)
        except websockets.ConnectionClosed:
            pass  # client went away; cleanup happens in finally
        except Exception as e:
            print(f"WebSocket {ws_id} handler error: {e}", file=sys.stderr)
        finally:
            self.close_socket_client(ws_id)
            self.websockets.pop(ws_id, None)

    async def create_socket_client(self, host=None, port=None):
        """Open a TCP connection to the backend.

        ``host``/``port`` default to ``self.tcp_host``/``self.tcp_port``.
        Returns the connected socket, or ``None`` if connecting fails (the
        failed socket is closed rather than leaked).
        """
        if host is None:
            host = self.tcp_host
        if port is None:
            port = self.tcp_port
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # connect() blocks; run it off the event loop.
            await asyncio.to_thread(sock.connect, (host, port))
            # Blocking mode: recv/sendall are executed in worker threads.
            sock.setblocking(True)
            return sock
        except Exception as e:
            print(f"TCP connect to {host}:{port} failed: {e}", file=sys.stderr)
            sock.close()
            return None

    async def send_to_socket(self, ws_id, message):
        """Write one WebSocket message to the TCP socket bridged to ``ws_id``.

        Text frames arrive as ``str`` and are UTF-8 encoded first, because
        ``socket.sendall`` only accepts bytes-like objects. Messages for a
        connection whose socket is already gone are dropped; a send failure
        closes that connection's socket.
        """
        sock = self.sockets.get(ws_id)
        if sock is None:
            return
        if isinstance(message, str):
            message = message.encode('utf-8')
        try:
            await asyncio.to_thread(sock.sendall, message)
        except Exception as e:
            print(f"Send to TCP socket for WebSocket {ws_id} failed: {e}", file=sys.stderr)
            self.close_socket_client(ws_id)

    async def receive_from_socket(self, ws_id):
        """Read from the TCP socket of ``ws_id`` and queue the data for its WebSocket.

        Stops when the backend closes the connection (``recv`` returns b''),
        when the socket fails or is shut down, or when the service stops
        running. The TCP socket is then closed and, while the service is still
        running, a ``(ws_id, None)`` close request is queued so the WebSocket
        is closed after any data already queued for it has been delivered.
        Without it the client would stay connected while its messages were
        silently dropped.
        """
        sock = self.sockets.get(ws_id)
        if sock is None:
            return
        try:
            while self.running:
                data = await asyncio.to_thread(sock.recv, 4096)
                if not data:
                    break  # backend closed its end of the connection
                await self.message_queue.put((ws_id, data))
        except OSError:
            pass  # socket shut down / closed, e.g. by close_socket_client()
        except Exception as e:
            print(f"TCP receive error for WebSocket {ws_id}: {e}", file=sys.stderr)
        finally:
            self.close_socket_client(ws_id)
            if self.running:
                # Unbounded queue: put_nowait cannot raise QueueFull.
                self.message_queue.put_nowait((ws_id, None))

    async def process_message_queue(self):
        """Deliver queued ``(ws_id, data)`` items to their WebSockets.

        Runs while the service is running, then drains what is still queued.
        ``data is None`` is a close request from receive_from_socket() and is
        handled in order, after data queued before it. Items for connections
        that no longer exist are dropped. ``task_done()`` is called for every
        item, whatever happens while delivering it.
        """
        queue = self.message_queue
        while self.running or not queue.empty():
            try:
                # Time out periodically so a change of self.running is noticed.
                ws_id, data = await asyncio.wait_for(queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            try:
                websocket = self.websockets.get(ws_id)
                if websocket is None:
                    continue
                # No `.open` pre-check: sending on a closed connection raises
                # ConnectionClosed (handled below), and not every websockets
                # connection class provides an `.open` attribute.
                if data is None:
                    await websocket.close()
                else:
                    await websocket.send(data)
            except websockets.ConnectionClosed:
                pass
            except Exception as e:
                print(f"Failed to forward data to WebSocket {ws_id}: {e}", file=sys.stderr)
            finally:
                queue.task_done()

    def close_socket_client(self, ws_id):
        """Shut down and close the TCP socket of ``ws_id``, if it still has one.

        The socket is popped from ``self.sockets`` first, so repeated calls
        are harmless. ``shutdown`` is issued before ``close`` so that a worker
        thread blocked in ``recv``/``sendall`` on this socket is woken up.
        """
        sock = self.sockets.pop(ws_id, None)
        if sock is None:
            return
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # peer already gone / socket not connected
        finally:
            sock.close()

    async def start(self, host='0.0.0.0', port=9001):
        """Serve WebSocket clients on ``host:port`` until the server closes.

        shutdown() always runs on the way out, including on cancellation.
        """
        # Created here so the queue belongs to the loop running this coroutine.
        self.message_queue = asyncio.Queue()
        self.server = await websockets.serve(self.handler, host, port)
        process_task = asyncio.create_task(self.process_message_queue())
        self.tasks.add(process_task)
        process_task.add_done_callback(self.tasks.discard)
        try:
            await self.server.wait_closed()
        except asyncio.CancelledError:
            pass
        finally:
            await self.shutdown()

    async def shutdown(self):
        """Stop the service: close all connections, cancel tasks, close the server.

        Only the first call does anything; later calls return immediately.
        """
        if not self.running:
            return
        self.running = False

        for ws_id, ws in list(self.websockets.items()):
            try:
                await ws.close()
            except Exception as e:
                print(f"Error closing WebSocket {ws_id}: {e}", file=sys.stderr)
        self.websockets.clear()

        for ws_id in list(self.sockets):
            self.close_socket_client(ws_id)

        # Every WebSocket is closed and forgotten, so queued data can no
        # longer be delivered. Cancel the workers rather than join()ing the
        # queue: a receive task may enqueue after the consumer has exited,
        # and join() would then wait forever.
        pending = list(self.tasks)
        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
        self.tasks.clear()

        if self.message_queue is not None:
            while not self.message_queue.empty():
                self.message_queue.get_nowait()
                self.message_queue.task_done()

        if self.server:
            self.server.close()
            await self.server.wait_closed()

    def start_service(self, host='0.0.0.0', port=9001):
        """Run the service on a new event loop in the calling thread (blocks).

        Intended as a ``threading.Thread`` target. Returns once the service
        has stopped; startup/runtime errors are reported on stderr.
        """
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        try:
            self.loop.run_until_complete(self.start(host=host, port=port))
        except Exception as e:
            print(f"Service exception: {e}", file=sys.stderr)
        finally:
            self.loop.close()
|
2024-11-06 18:42:52 +08:00
|
|
|
|
|
|
|
if __name__ == '__main__':
    service = new_instance()
    service_thread = threading.Thread(target=service.start_service, daemon=True)
    service_thread.start()

    try:
        # Stay alive while the service thread runs. If it dies on its own
        # (e.g. the WebSocket port is already in use), exit instead of
        # sleeping forever with no server behind us.
        while service_thread.is_alive():
            time.sleep(1)
    except KeyboardInterrupt:
        # Run the shutdown coroutine on the service's own event loop.
        print("Initiating shutdown...")
        if service.loop and service.loop.is_running():
            future = asyncio.run_coroutine_threadsafe(service.shutdown(), service.loop)
            try:
                # Wait for shutdown to finish, but never hang indefinitely.
                future.result(timeout=10)
                print("Shutdown coroutine completed.")
            except Exception as e:
                print(f"Shutdown exception: {e!r}", file=sys.stderr)

    # Bounded wait: the thread is a daemon, so the process can still exit.
    service_thread.join(timeout=10)
    print("Service has been shut down.")
|