"""
web/ws_bus.py
=============

In-process WebSocket event bus.

All connected web clients share a single WSBus instance.  Events are
delivered per-user (private results) or to all users (shared status).

Usage::

    bus = WSBus()

    # Subscribe (returns a queue; caller reads from it)
    q = bus.subscribe("alice")

    # Broadcast to all
    await bus.broadcast("status_snapshot", {...})

    # Broadcast to one user (all their open tabs)
    await bus.broadcast_to_user("alice", "generation_complete", {...})

    # Unsubscribe when WS disconnects
    bus.unsubscribe("alice", q)

Event frame format (a JSON string placed on each subscriber queue)::

    {"type": "event_name", "data": {...}, "ts": 1234567890.123}
"""
from __future__ import annotations

import asyncio
import json
import logging
import time
from typing import Any, Dict, Set

logger = logging.getLogger(__name__)


class WSBus:
    """
    Per-user broadcast bus backed by bounded asyncio queues.

    Every subscription gets its own ``asyncio.Queue`` of JSON-encoded frames.
    Publishing never blocks: if a subscriber's queue is full, the event is
    dropped for that subscriber and a warning is logged.

    Not thread-safe: ``asyncio.Queue`` is not thread-safe, so all methods
    must be called from the thread running the event loop that uses the bus.
    """

    def __init__(self, max_queue_size: int = 256) -> None:
        """
        Args:
            max_queue_size: Capacity of each per-connection queue.  Values
                <= 0 make the queues unbounded (``asyncio.Queue`` semantics).
        """
        self._max_queue_size = max_queue_size
        # user_label → set of asyncio.Queue (one per subscribe() call)
        self._clients: Dict[str, Set[asyncio.Queue]] = {}

    # ------------------------------------------------------------------
    # Subscription lifecycle
    # ------------------------------------------------------------------

    def subscribe(self, user_label: str) -> asyncio.Queue:
        """Register a new client connection for *user_label*.

        Returns the queue the caller should read JSON frames from.
        """
        q: asyncio.Queue = asyncio.Queue(maxsize=self._max_queue_size)
        queues = self._clients.setdefault(user_label, set())
        queues.add(q)
        logger.debug("WSBus: %s subscribed (%d queues)", user_label, len(queues))
        return q

    def unsubscribe(self, user_label: str, queue: asyncio.Queue) -> None:
        """Remove a client connection.

        Unknown users or queues are ignored.  The user's entry is removed
        once their last queue has been unsubscribed.
        """
        queues = self._clients.get(user_label)
        if queues is not None:
            queues.discard(queue)
            if not queues:
                del self._clients[user_label]
        logger.debug("WSBus: %s unsubscribed", user_label)

    @property
    def connected_users(self) -> list[str]:
        """List of user labels with at least one active connection."""
        return list(self._clients.keys())

    @property
    def total_connections(self) -> int:
        """Total number of subscribed queues across all users."""
        return sum(len(qs) for qs in self._clients.values())

    # ------------------------------------------------------------------
    # Broadcasting
    # ------------------------------------------------------------------

    def _frame(self, event_type: str, data: Any) -> str:
        """Encode one event as a JSON frame string.

        Raises:
            TypeError: if *data* is not JSON-serialisable.
        """
        return json.dumps({"type": event_type, "data": data, "ts": time.time()})

    @staticmethod
    def _offer(q: asyncio.Queue, frame: str) -> bool:
        """Enqueue *frame* without blocking; return False if *q* is full."""
        try:
            q.put_nowait(frame)
        except asyncio.QueueFull:
            return False
        return True

    async def broadcast(self, event_type: str, data: Any) -> None:
        """Send an event to ALL connected clients.

        Subscribers whose queue is full miss this event (a warning is logged).
        """
        frame = self._frame(event_type, data)
        # Defensive snapshots of the registry before delivering.
        for queues in list(self._clients.values()):
            for q in list(queues):
                if not self._offer(q, frame):
                    logger.warning("WSBus: queue full, dropping %s event", event_type)

    async def broadcast_to_user(
        self, user_label: str, event_type: str, data: Any
    ) -> None:
        """Send an event to all connections belonging to *user_label*.

        Does nothing (beyond a debug log) if the user has no connections.
        Connections whose queue is full miss this event (a warning is logged).
        """
        queues = self._clients.get(user_label, set())
        if not queues:
            logger.debug("WSBus: no clients for user '%s', dropping %s",
                         user_label, event_type)
            return
        frame = self._frame(event_type, data)
        for q in list(queues):
            if not self._offer(q, frame):
                logger.warning("WSBus: queue full for %s, dropping %s",
                               user_label, event_type)


# Module-level singleton (may be assigned by web/app.py; otherwise created
# lazily by get_bus()).
_bus: WSBus | None = None


def get_bus() -> WSBus:
    """Return the process-wide WSBus, creating it with defaults on first use."""
    global _bus
    if _bus is None:
        _bus = WSBus()
    return _bus