Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI (React/TS/Vite/Tailwind), ComfyUI integration, generation history DB, preset manager, workflow inspector, and all supporting modules. Excluded from tracking: .env, invite_tokens.json, *.db (SQLite), current-workflow-changes.json, user_settings/, presets/, logs/, web-static/ (build output), frontend/node_modules/. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
123 lines
3.8 KiB
Python
123 lines
3.8 KiB
Python
"""
|
|
web/ws_bus.py
|
|
=============
|
|
|
|
In-process WebSocket event bus.
|
|
|
|
All connected web clients share a single WSBus instance. Events are
|
|
delivered per-user (private results) or to all users (shared status).
|
|
|
|
Usage::
|
|
|
|
bus = WSBus()
|
|
|
|
# Subscribe (returns a queue; caller reads from it)
|
|
q = bus.subscribe("alice")
|
|
|
|
# Broadcast to all
|
|
await bus.broadcast("status_snapshot", {...})
|
|
|
|
# Broadcast to one user (all their open tabs)
|
|
await bus.broadcast_to_user("alice", "generation_complete", {...})
|
|
|
|
# Unsubscribe when WS disconnects
|
|
bus.unsubscribe("alice", q)
|
|
|
|
Event frame format sent on wire:
|
|
{"type": "event_name", "data": {...}, "ts": 1234567890.123}
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
from typing import Any, Dict, Set
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class WSBus:
    """
    Per-user broadcast bus backed by asyncio queues.

    Each open connection owns one bounded ``asyncio.Queue``. Queues are
    grouped by user label, so an event can be delivered to a single user
    (all of their connections) or to every connected client.

    Delivery never blocks: when a connection's queue is full the new event
    is dropped for that connection and a warning is logged.

    Not thread-safe. The bus takes no locks and asyncio queues are not
    thread-safe, so every caller must run in the same event loop.
    """

    # Default capacity of each per-connection queue.
    DEFAULT_MAX_QUEUE_SIZE = 256

    def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None:
        """
        Create an empty bus.

        Args:
            max_queue_size: Capacity of every queue handed out by
                :meth:`subscribe`. Passed straight to ``asyncio.Queue``,
                so a value <= 0 makes the queues unbounded.
        """
        self._max_queue_size = max_queue_size
        # user_label → set of asyncio.Queue (one queue per open connection)
        self._clients: Dict[str, Set[asyncio.Queue]] = {}

    # ------------------------------------------------------------------
    # Subscription lifecycle
    # ------------------------------------------------------------------

    def subscribe(self, user_label: str) -> asyncio.Queue:
        """Register a new client connection. Returns the queue to read from."""
        q: asyncio.Queue = asyncio.Queue(maxsize=self._max_queue_size)
        queues = self._clients.setdefault(user_label, set())
        queues.add(q)
        logger.debug("WSBus: %s subscribed (%d queues)", user_label, len(queues))
        return q

    def unsubscribe(self, user_label: str, queue: asyncio.Queue) -> None:
        """
        Remove a client connection.

        Unknown labels and queues that were never registered are ignored.
        A user's entry is removed once their last queue is gone.
        """
        queues = self._clients.get(user_label)
        if queues is not None:
            queues.discard(queue)
            if not queues:
                del self._clients[user_label]
        logger.debug("WSBus: %s unsubscribed", user_label)

    @property
    def connected_users(self) -> list[str]:
        """List of user labels with at least one active connection."""
        return list(self._clients)

    @property
    def total_connections(self) -> int:
        """Number of registered queues across all users."""
        return sum(len(qs) for qs in self._clients.values())

    # ------------------------------------------------------------------
    # Broadcasting
    # ------------------------------------------------------------------

    def _frame(self, event_type: str, data: Any) -> str:
        """Serialize one event to the JSON wire format, stamped with the current time."""
        return json.dumps({"type": event_type, "data": data, "ts": time.time()})

    def _deliver(
        self,
        user_label: str,
        queues: Set[asyncio.Queue],
        event_type: str,
        frame: str,
    ) -> None:
        """Enqueue *frame* on each of one user's queues, dropping it where a queue is full."""
        # Iterate over a snapshot so the loop does not depend on the live set.
        for q in list(queues):
            try:
                q.put_nowait(frame)
            except asyncio.QueueFull:
                logger.warning(
                    "WSBus: queue full for %s, dropping %s", user_label, event_type
                )

    async def broadcast(self, event_type: str, data: Any) -> None:
        """Send an event to ALL connected clients."""
        # Serialize once; every connection receives the same frame string.
        frame = self._frame(event_type, data)
        for user_label, queues in list(self._clients.items()):
            self._deliver(user_label, queues, event_type, frame)

    async def broadcast_to_user(
        self, user_label: str, event_type: str, data: Any
    ) -> None:
        """Send an event to all connections belonging to *user_label*."""
        queues = self._clients.get(user_label)
        if not queues:
            logger.debug(
                "WSBus: no clients for user '%s', dropping %s", user_label, event_type
            )
            return
        self._deliver(user_label, queues, event_type, self._frame(event_type, data))
|
|
|
|
|
|
# Module-level singleton, created on demand by get_bus().
# NOTE(review): the original comment says web/app.py sets this; nothing in
# this module assigns it other than get_bus() — confirm against the caller.
_bus: WSBus | None = None


def get_bus() -> WSBus:
    """Return the shared WSBus instance, creating it on the first call."""
    global _bus
    if _bus is not None:
        return _bus
    _bus = WSBus()
    return _bus
|