Initial commit — ComfyUI Discord bot + web UI
Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI (React/TS/Vite/Tailwind), ComfyUI integration, generation history DB, preset manager, workflow inspector, and all supporting modules. Excluded from tracking: .env, invite_tokens.json, *.db (SQLite), current-workflow-changes.json, user_settings/, presets/, logs/, web-static/ (build output), frontend/node_modules/. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
122
web/ws_bus.py
Normal file
122
web/ws_bus.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
web/ws_bus.py
|
||||
=============
|
||||
|
||||
In-process WebSocket event bus.
|
||||
|
||||
All connected web clients share a single WSBus instance. Events are
|
||||
delivered per-user (private results) or to all users (shared status).
|
||||
|
||||
Usage::
|
||||
|
||||
bus = WSBus()
|
||||
|
||||
# Subscribe (returns a queue; caller reads from it)
|
||||
q = bus.subscribe("alice")
|
||||
|
||||
# Broadcast to all
|
||||
await bus.broadcast("status_snapshot", {...})
|
||||
|
||||
# Broadcast to one user (all their open tabs)
|
||||
await bus.broadcast_to_user("alice", "generation_complete", {...})
|
||||
|
||||
# Unsubscribe when WS disconnects
|
||||
bus.unsubscribe("alice", q)
|
||||
|
||||
Event frame format sent on wire:
|
||||
{"type": "event_name", "data": {...}, "ts": 1234567890.123}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WSBus:
|
||||
"""
|
||||
Per-user broadcast bus backed by asyncio queues.
|
||||
|
||||
Thread-safe as long as all callers run in the same event loop.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# user_label → set of asyncio.Queue
|
||||
self._clients: Dict[str, Set[asyncio.Queue]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Subscription lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def subscribe(self, user_label: str) -> asyncio.Queue:
|
||||
"""Register a new client connection. Returns the queue to read from."""
|
||||
q: asyncio.Queue = asyncio.Queue(maxsize=256)
|
||||
self._clients.setdefault(user_label, set()).add(q)
|
||||
logger.debug("WSBus: %s subscribed (%d queues)", user_label,
|
||||
len(self._clients[user_label]))
|
||||
return q
|
||||
|
||||
def unsubscribe(self, user_label: str, queue: asyncio.Queue) -> None:
|
||||
"""Remove a client connection."""
|
||||
queues = self._clients.get(user_label, set())
|
||||
queues.discard(queue)
|
||||
if not queues:
|
||||
self._clients.pop(user_label, None)
|
||||
logger.debug("WSBus: %s unsubscribed", user_label)
|
||||
|
||||
@property
|
||||
def connected_users(self) -> list[str]:
|
||||
"""List of user labels with at least one active connection."""
|
||||
return list(self._clients.keys())
|
||||
|
||||
@property
|
||||
def total_connections(self) -> int:
|
||||
return sum(len(qs) for qs in self._clients.values())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Broadcasting
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _frame(self, event_type: str, data: Any) -> str:
|
||||
return json.dumps({"type": event_type, "data": data, "ts": time.time()})
|
||||
|
||||
async def broadcast(self, event_type: str, data: Any) -> None:
|
||||
"""Send an event to ALL connected clients."""
|
||||
frame = self._frame(event_type, data)
|
||||
for queues in list(self._clients.values()):
|
||||
for q in list(queues):
|
||||
try:
|
||||
q.put_nowait(frame)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("WSBus: queue full, dropping %s event", event_type)
|
||||
|
||||
async def broadcast_to_user(
|
||||
self, user_label: str, event_type: str, data: Any
|
||||
) -> None:
|
||||
"""Send an event to all connections belonging to *user_label*."""
|
||||
queues = self._clients.get(user_label, set())
|
||||
if not queues:
|
||||
logger.debug("WSBus: no clients for user '%s', dropping %s", user_label, event_type)
|
||||
return
|
||||
frame = self._frame(event_type, data)
|
||||
for q in list(queues):
|
||||
try:
|
||||
q.put_nowait(frame)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("WSBus: queue full for %s, dropping %s", user_label, event_type)
|
||||
|
||||
|
||||
# Module-level singleton (set by web/app.py)
|
||||
_bus: WSBus | None = None
|
||||
|
||||
|
||||
def get_bus() -> WSBus:
|
||||
global _bus
|
||||
if _bus is None:
|
||||
_bus = WSBus()
|
||||
return _bus
|
||||
Reference in New Issue
Block a user