"""
WebSocketManager — asyncio-side manager for WebSocket connections.

All methods are coroutines and must be called from the asyncio event loop.
No lock is needed because the event loop is single-threaded, but every
``await`` is a point where other coroutines may run and add or remove
connections. Code that awaits while walking the registry therefore iterates
over a snapshot.

Subscription model:
- Each connection is registered with a set of server_ids.
- Subscribing to server_id=None means "all servers"; connecting with
  ``server_ids=None`` or an empty list registers the connection as {None}.
- broadcast(server_id, message) sends to all clients subscribed to that
  server_id plus all clients subscribed to None (global subscribers).
"""

from __future__ import annotations

import json
import logging
from typing import Optional

from fastapi import WebSocket

logger = logging.getLogger(__name__)


class WebSocketManager:
    """Manages active WebSocket connections and delivers broadcast messages."""

    def __init__(self) -> None:
        # Maps WebSocket -> set of subscribed server_ids (None = all servers).
        self._connections: dict[WebSocket, set[Optional[int]]] = {}

    # ── Connection lifecycle ──

    async def connect(self, ws: WebSocket, server_ids: Optional[list[int]] = None) -> None:
        """
        Accept a WebSocket connection and register it.

        Args:
            ws: The FastAPI WebSocket instance.
            server_ids: List of server IDs to subscribe to, or None for all.
                An empty list is treated the same as None (all servers).
        """
        await ws.accept()
        subscriptions: set[Optional[int]] = set(server_ids) if server_ids else {None}
        self._connections[ws] = subscriptions
        logger.info(
            "WebSocketManager: client connected, subscriptions=%s, total=%d",
            subscriptions,
            len(self._connections),
        )

    async def disconnect(self, ws: WebSocket) -> None:
        """
        Remove a disconnected WebSocket from the registry.

        Safe to call more than once for the same socket (e.g. once after a
        failed send and again from the endpoint's own cleanup). Only the call
        that actually removes the connection logs, so the log is not
        duplicated.
        """
        # Registered values are always non-empty sets, never None, so a None
        # result means the socket was not (or no longer) registered.
        if self._connections.pop(ws, None) is None:
            return
        logger.info(
            "WebSocketManager: client disconnected, total=%d",
            len(self._connections),
        )

    # ── Broadcast ──

    async def broadcast(self, server_id: Optional[int], message: dict) -> None:
        """
        Send a message to all clients subscribed to the given server_id.

        Also sends to clients subscribed to None (global subscribers). When
        ``server_id`` is None, only global subscribers match. Clients whose
        send raises are removed automatically.

        Args:
            server_id: Target server ID, or None.
            message: Payload to serialize with ``json.dumps``.

        Raises:
            TypeError / ValueError: from ``json.dumps`` if ``message`` is not
                JSON-serializable (raised before anything is sent).
        """
        if not self._connections:
            return

        payload = json.dumps(message)

        # Snapshot the targets: each ``await`` below yields to the event loop,
        # and a concurrent connect()/disconnect() would otherwise mutate the
        # dict mid-iteration and raise RuntimeError.
        targets = [
            ws
            for ws, subscriptions in list(self._connections.items())
            if None in subscriptions or server_id in subscriptions
        ]

        disconnected: list[WebSocket] = []
        for ws in targets:
            if ws not in self._connections:
                # Removed by another coroutine while an earlier send was awaited.
                continue
            try:
                await ws.send_text(payload)
            except Exception as exc:
                logger.debug("WebSocketManager: send failed, marking disconnected: %s", exc)
                disconnected.append(ws)

        for ws in disconnected:
            await self.disconnect(ws)

    async def send_to_connection(self, ws: WebSocket, message: dict) -> None:
        """
        Send a message to a single specific connection.

        If the send raises, the connection is removed via ``disconnect``.

        Raises:
            TypeError / ValueError: from ``json.dumps`` if ``message`` is not
                JSON-serializable. This is a caller bug, so it is propagated
                rather than being mistaken for a dead connection.
        """
        # Serialize outside the try (as broadcast does) so that encoding
        # errors are not swallowed and do not drop a healthy client.
        payload = json.dumps(message)
        try:
            await ws.send_text(payload)
        except Exception as exc:
            logger.debug("WebSocketManager: direct send failed, disconnecting: %s", exc)
            await self.disconnect(ws)

    # ── Stats ──

    @property
    def connection_count(self) -> int:
        """Number of currently registered connections."""
        return len(self._connections)