Initial commit — ComfyUI Discord bot + web UI

Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI
(React/TS/Vite/Tailwind), ComfyUI integration, generation history DB,
preset manager, workflow inspector, and all supporting modules.

Excluded from tracking: .env, invite_tokens.json, *.db (SQLite),
current-workflow-changes.json, user_settings/, presets/, logs/,
web-static/ (build output), frontend/node_modules/.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Khoa (Revenovich) Tran Gia
2026-03-02 09:55:48 +07:00
commit 1ed3c9ec4b
82 changed files with 20693 additions and 0 deletions

1
web/__init__.py Normal file
View File

@@ -0,0 +1 @@
# web package

269
web/app.py Normal file
View File

@@ -0,0 +1,269 @@
"""
web/app.py
==========
FastAPI application factory.
The app is created once and shared between the Uvicorn server (started
from bot.py via asyncio.gather) and tests.
Startup tasks:
- Background status ticker (broadcasts status_snapshot every 5s to all clients)
- Background NSSM poll (broadcasts server_state every 10s to all clients)
"""
from __future__ import annotations
import asyncio
import logging
import mimetypes
import os
from pathlib import Path
from fastapi import FastAPI
from starlette.exceptions import HTTPException as _HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request as _Request
# Windows registry can map .js → text/plain; override to the correct types
# before StaticFiles reads them.
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("application/javascript", ".mjs")
mimetypes.add_type("text/css", ".css")
mimetypes.add_type("application/wasm", ".wasm")
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
class _NoCacheHTMLMiddleware(BaseHTTPMiddleware):
"""Force browsers to revalidate index.html on every request.
Vite hashes JS/CSS filenames on every build so those assets are
naturally cache-busted. index.html itself has a stable name, so
without an explicit Cache-Control header mobile browsers apply
heuristic caching and keep serving a stale copy after a redeploy.
"""
async def dispatch(self, request: _Request, call_next):
response = await call_next(request)
ct = response.headers.get("content-type", "")
if "text/html" in ct:
response.headers["Cache-Control"] = "no-cache, must-revalidate"
return response
class _SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Add security headers to every response."""
_CSP = (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline'; "
"style-src 'self' 'unsafe-inline'; "
"img-src 'self' data: blob:; "
"connect-src 'self' wss:; "
"frame-ancestors 'none';"
)
async def dispatch(self, request: _Request, call_next):
response = await call_next(request)
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
response.headers["Content-Security-Policy"] = self._CSP
return response
class _SPAStaticFiles(StaticFiles):
"""StaticFiles with SPA fallback: serve index.html for unknown paths.
Starlette's html=True only serves index.html for directory requests.
This subclass additionally returns index.html for any path that has no
matching file, so client-side routes like /generate work on refresh.
"""
async def get_response(self, path: str, scope):
try:
return await super().get_response(path, scope)
except _HTTPException as ex:
if ex.status_code == 404:
return await super().get_response("index.html", scope)
raise
from web.ws_bus import get_bus
logger = logging.getLogger(__name__)
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
_WEB_STATIC = _PROJECT_ROOT / "web-static"
def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
app = FastAPI(
title="ComfyUI Bot Web UI",
version="1.0.0",
docs_url=None,
redoc_url=None,
openapi_url=None,
)
# CORS — only allow explicitly configured origins; empty = no cross-origin
_cors_origins = [o.strip() for o in os.getenv("CORS_ORIGINS", "").split(",") if o.strip()]
if _cors_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=_cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Security headers on every response
app.add_middleware(_SecurityHeadersMiddleware)
# Prevent browsers from caching index.html across deploys
app.add_middleware(_NoCacheHTMLMiddleware)
# Register API routers
from web.routers.auth_router import router as auth_router
from web.routers.admin_router import router as admin_router
from web.routers.status_router import router as status_router
from web.routers.state_router import router as state_router
from web.routers.generate_router import router as generate_router
from web.routers.inputs_router import router as inputs_router
from web.routers.presets_router import router as presets_router
from web.routers.server_router import router as server_router
from web.routers.history_router import router as history_router
from web.routers.share_router import router as share_router
from web.routers.workflow_router import router as workflow_router
from web.routers.ws_router import router as ws_router
app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
app.include_router(admin_router, prefix="/api/admin", tags=["admin"])
app.include_router(status_router, prefix="/api", tags=["status"])
app.include_router(state_router, prefix="/api", tags=["state"])
app.include_router(generate_router, prefix="/api", tags=["generate"])
app.include_router(inputs_router, prefix="/api/inputs", tags=["inputs"])
app.include_router(presets_router, prefix="/api/presets", tags=["presets"])
app.include_router(server_router, prefix="/api", tags=["server"])
app.include_router(history_router, prefix="/api/history", tags=["history"])
app.include_router(share_router, prefix="/api/share", tags=["share"])
app.include_router(workflow_router, prefix="/api/workflow", tags=["workflow"])
app.include_router(ws_router, tags=["ws"])
# Serve frontend static files (if built)
if _WEB_STATIC.exists() and any(_WEB_STATIC.iterdir()):
app.mount("/", _SPAStaticFiles(directory=str(_WEB_STATIC), html=True), name="static")
logger.info("Serving frontend from %s", _WEB_STATIC)
@app.on_event("startup")
async def _startup():
asyncio.create_task(_status_ticker())
asyncio.create_task(_server_state_poller())
logger.info("Web background tasks started")
return app
# ---------------------------------------------------------------------------
# Background tasks
# ---------------------------------------------------------------------------
async def _status_ticker() -> None:
"""Broadcast status_snapshot to all clients every 5 seconds."""
from web.deps import get_bot, get_comfy, get_config
bus = get_bus()
while True:
await asyncio.sleep(5)
try:
bot = get_bot()
comfy = get_comfy()
config = get_config()
snapshot: dict = {}
if bot is not None:
lat = bot.latency
lat_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0
import datetime as _dt
start = getattr(bot, "start_time", None)
uptime = ""
if start:
delta = _dt.datetime.now(_dt.timezone.utc) - start
total = int(delta.total_seconds())
h, rem = divmod(total, 3600)
m, s = divmod(rem, 60)
uptime = f"{h}h {m}m {s}s" if h else (f"{m}m {s}s" if m else f"{s}s")
snapshot["bot"] = {"latency_ms": lat_ms, "uptime": uptime}
if comfy is not None:
q = await comfy.get_comfy_queue()
pending = len(q.get("queue_pending", [])) if q else 0
running = len(q.get("queue_running", [])) if q else 0
wm = getattr(comfy, "workflow_manager", None)
wf_loaded = wm is not None and wm.get_workflow_template() is not None
snapshot["comfy"] = {
"server": comfy.server_address,
"queue_pending": pending,
"queue_running": running,
"workflow_loaded": wf_loaded,
"last_seed": comfy.last_seed,
"total_generated": comfy.total_generated,
}
if config is not None:
from media_uploader import get_stats as get_upload_stats, is_running as upload_running
try:
us = get_upload_stats()
snapshot["upload"] = {
"configured": bool(config.media_upload_user),
"running": upload_running(),
"total_ok": us.total_ok,
"total_fail": us.total_fail,
}
except Exception:
pass
from web.deps import get_user_registry
registry = get_user_registry()
connected = bus.connected_users
if connected and registry:
for ul in connected:
user_overrides = registry.get_state_manager(ul).get_overrides()
await bus.broadcast_to_user(ul, "status_snapshot", {**snapshot, "overrides": user_overrides})
else:
await bus.broadcast("status_snapshot", snapshot)
except Exception as exc:
logger.debug("Status ticker error: %s", exc)
async def _server_state_poller() -> None:
"""Poll NSSM service state and broadcast server_state every 10 seconds."""
from web.deps import get_config
bus = get_bus()
while True:
await asyncio.sleep(10)
try:
config = get_config()
if config is None:
continue
from commands.server import get_service_state
from web.deps import get_comfy
async def _false():
return False
comfy = get_comfy()
service_state, http_reachable = await asyncio.gather(
get_service_state(config.comfy_service_name),
comfy.check_connection() if comfy else _false(),
)
await bus.broadcast("server_state", {
"state": service_state,
"http_reachable": http_reachable,
})
except Exception as exc:
logger.debug("Server state poller error: %s", exc)

119
web/auth.py Normal file
View File

@@ -0,0 +1,119 @@
"""
web/auth.py
===========
JWT authentication for the web UI.
Flow:
- POST /api/auth/login {token} → verify invite token → issue JWT in httpOnly cookie
- All /api/* require valid JWT via require_auth dependency
- POST /api/admin/login {password} → issue admin JWT (admin: true claim)
- WS /ws?token=<jwt> → authenticate via query param
JWT claims: {"sub": "<label>", "admin": bool, "exp": ...}
"""
from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import Cookie, Depends, HTTPException, status
from fastapi.security import HTTPBearer
try:
from jose import JWTError, jwt
except ImportError:
jwt = None # type: ignore
JWTError = Exception # type: ignore
logger = logging.getLogger(__name__)
ALGORITHM = "HS256"
_COOKIE_NAME = "ttb_session"
def _get_secret() -> str:
from web.deps import get_config
cfg = get_config()
if cfg and cfg.web_secret_key:
return cfg.web_secret_key
raise RuntimeError(
"WEB_SECRET_KEY must be set in the environment — "
"refusing to run with an insecure default."
)
def create_jwt(label: str, *, admin: bool = False, expire_hours: int = 8) -> str:
"""Create a signed JWT for the given user label."""
if jwt is None:
raise RuntimeError("python-jose is not installed (pip install python-jose[cryptography])")
expire = datetime.now(timezone.utc) + timedelta(hours=expire_hours)
payload = {"sub": label, "admin": admin, "exp": expire}
return jwt.encode(payload, _get_secret(), algorithm=ALGORITHM)
def decode_jwt(token: str) -> Optional[dict]:
"""Decode and verify a JWT. Returns the payload or None on failure."""
if jwt is None:
return None
try:
return jwt.decode(token, _get_secret(), algorithms=[ALGORITHM])
except JWTError as exc:
logger.debug("JWT decode failed: %s", exc)
return None
def verify_ws_token(token: str) -> Optional[dict]:
"""Verify a JWT passed as a WebSocket query parameter."""
return decode_jwt(token)
# ---------------------------------------------------------------------------
# FastAPI dependencies
# ---------------------------------------------------------------------------
def require_auth(ttb_session: Optional[str] = Cookie(default=None)) -> dict:
"""
FastAPI dependency that requires a valid JWT cookie.
Returns
-------
dict
The decoded JWT payload (``sub``, ``admin`` fields).
Raises
------
HTTPException 401
If the cookie is absent or the token is invalid/expired.
"""
if not ttb_session:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
)
payload = decode_jwt(ttb_session)
if payload is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
)
return payload
def require_admin(user: dict = Depends(require_auth)) -> dict:
"""
FastAPI dependency that requires an admin JWT.
Raises
------
HTTPException 403
If the token is valid but not admin.
"""
if not user.get("admin"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin access required",
)
return user

72
web/deps.py Normal file
View File

@@ -0,0 +1,72 @@
"""
web/deps.py
===========
Shared bot reference for FastAPI dependency injection.
``set_bot()`` is called once from ``bot.py`` before starting Uvicorn.
FastAPI route handlers use ``get_bot()``, ``get_comfy()``, etc. as
Depends() callables.
"""
from __future__ import annotations
from typing import Optional
_bot = None
def set_bot(bot) -> None:
"""Store the discord.py bot instance for DI access."""
global _bot
_bot = bot
def get_bot():
"""FastAPI dependency: return the bot instance."""
return _bot
def get_comfy():
"""FastAPI dependency: return the ComfyClient."""
if _bot is None:
return None
return getattr(_bot, "comfy", None)
def get_config():
"""FastAPI dependency: return the BotConfig."""
if _bot is None:
return None
return getattr(_bot, "config", None)
def get_state_manager():
"""FastAPI dependency: return the WorkflowStateManager."""
comfy = get_comfy()
if comfy is None:
return None
return getattr(comfy, "state_manager", None)
def get_workflow_manager():
"""FastAPI dependency: return the WorkflowManager."""
comfy = get_comfy()
if comfy is None:
return None
return getattr(comfy, "workflow_manager", None)
def get_inspector():
"""FastAPI dependency: return the WorkflowInspector."""
comfy = get_comfy()
if comfy is None:
return None
return getattr(comfy, "inspector", None)
def get_user_registry():
"""FastAPI dependency: return the UserStateRegistry."""
if _bot is None:
return None
return getattr(_bot, "user_registry", None)

146
web/login_guard.py Normal file
View File

@@ -0,0 +1,146 @@
"""
web/login_guard.py
==================
IP-based brute-force protection for login endpoints.
Tracks failed login attempts per IP in a rolling time window and issues a
temporary ban when the threshold is exceeded. Uses only stdlib — no new
pip packages required.
Usage
-----
from web.login_guard import get_guard, get_real_ip
@router.post("/login")
async def login(request: Request, body: LoginRequest, response: Response):
ip = get_real_ip(request)
get_guard().check(ip) # raises 429 if locked out
...
if failure:
get_guard().record_failure(ip)
raise HTTPException(401, ...)
get_guard().record_success(ip)
...
"""
from __future__ import annotations
import logging
import time
from collections import defaultdict
from typing import Dict, List
from fastapi import HTTPException, Request, status
logger = logging.getLogger(__name__)
def get_real_ip(request: Request) -> str:
"""Return the real client IP, honouring Cloudflare and common proxy headers.
Priority:
1. ``CF-Connecting-IP`` (set by Cloudflare)
2. ``X-Real-IP`` (set by nginx/traefik)
3. ``request.client.host`` (direct connection fallback)
"""
cf_ip = request.headers.get("CF-Connecting-IP", "").strip()
if cf_ip:
return cf_ip
real_ip = request.headers.get("X-Real-IP", "").strip()
if real_ip:
return real_ip
return request.client.host if request.client else "unknown"
class BruteForceGuard:
"""Rolling-window failure counter with automatic IP bans.
All state is in-process memory. A restart clears all bans and counters,
which is acceptable — a brief restart already provides a natural backoff
for a legitimate attacker.
"""
WINDOW_SECS = 600 # rolling window: 10 minutes
MAX_FAILURES = 10 # max failures before ban
BAN_SECS = 3600 # ban duration: 1 hour
def __init__(self) -> None:
# ip → list of failure timestamps (epoch floats)
self._failures: Dict[str, List[float]] = defaultdict(list)
# ip → ban expiry timestamp
self._ban_until: Dict[str, float] = {}
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def check(self, ip: str) -> None:
"""Raise HTTP 429 if the IP is banned or has exceeded the failure threshold.
Call this *before* doing any credential work so the lockout is
evaluated even when the request body is malformed.
"""
now = time.time()
# Active ban?
ban_expiry = self._ban_until.get(ip, 0)
if ban_expiry > now:
logger.warning("login_guard: blocked request from banned ip=%s (ban expires in %.0fs)", ip, ban_expiry - now)
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Too many failed attempts. Try again later.",
)
# Failure count within rolling window
cutoff = now - self.WINDOW_SECS
recent = [t for t in self._failures[ip] if t > cutoff]
self._failures[ip] = recent # prune stale entries while we're here
if len(recent) >= self.MAX_FAILURES:
# Threshold just reached — apply ban now
self._ban_until[ip] = now + self.BAN_SECS
logger.warning(
"login_guard: threshold reached, banning ip=%s for %ds",
ip, self.BAN_SECS,
)
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Too many failed attempts. Try again later.",
)
def record_failure(self, ip: str) -> None:
"""Record a failed login attempt for the given IP."""
now = time.time()
cutoff = now - self.WINDOW_SECS
recent = [t for t in self._failures[ip] if t > cutoff]
recent.append(now)
self._failures[ip] = recent
count = len(recent)
logger.warning("login_guard: failure #%d from ip=%s", count, ip)
if count >= self.MAX_FAILURES:
self._ban_until[ip] = now + self.BAN_SECS
logger.warning(
"login_guard: threshold reached, banning ip=%s for %ds",
ip, self.BAN_SECS,
)
def record_success(self, ip: str) -> None:
"""Clear failure history and any active ban for the given IP."""
self._failures.pop(ip, None)
self._ban_until.pop(ip, None)
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
_guard: BruteForceGuard | None = None
def get_guard() -> BruteForceGuard:
"""Return the shared BruteForceGuard singleton (created on first call)."""
global _guard
if _guard is None:
_guard = BruteForceGuard()
return _guard

1
web/routers/__init__.py Normal file
View File

@@ -0,0 +1 @@
# web.routers package

View File

@@ -0,0 +1,88 @@
"""POST /api/admin/login; GET/POST/DELETE /api/admin/tokens"""
from __future__ import annotations
import hmac
import logging
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from pydantic import BaseModel
from web.auth import create_jwt, require_admin
from web.deps import get_config
from web.login_guard import get_guard, get_real_ip
router = APIRouter()
_COOKIE = "ttb_session"
audit = logging.getLogger("audit")
class AdminLoginRequest(BaseModel):
password: str
class CreateTokenRequest(BaseModel):
label: str
admin: bool = False
@router.post("/login")
async def admin_login(body: AdminLoginRequest, request: Request, response: Response):
"""Admin password login → admin JWT cookie."""
config = get_config()
expected_pw = config.admin_password if config else None
ip = get_real_ip(request)
get_guard().check(ip)
# Constant-time comparison to prevent timing attacks
if not expected_pw or not hmac.compare_digest(body.password, expected_pw):
get_guard().record_failure(ip)
audit.info("admin.login ip=%s success=False", ip)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Wrong password")
get_guard().record_success(ip)
audit.info("admin.login ip=%s success=True", ip)
expire_hours = config.web_jwt_expire_hours if config else 8
jwt_token = create_jwt("admin", admin=True, expire_hours=expire_hours)
response.set_cookie(
_COOKIE, jwt_token,
httponly=True, secure=True, samesite="strict",
max_age=expire_hours * 3600,
)
return {"label": "admin", "admin": True}
@router.get("/tokens")
async def list_tokens(_: dict = Depends(require_admin)):
"""List all invite tokens (hashes shown, labels safe)."""
from token_store import list_tokens as _list
config = get_config()
token_file = config.web_token_file if config else "invite_tokens.json"
records = _list(token_file)
# Don't return hashes to the UI
return [{"id": r["id"], "label": r["label"], "admin": r.get("admin", False),
"created_at": r.get("created_at")} for r in records]
@router.post("/tokens")
async def create_token(body: CreateTokenRequest, _: dict = Depends(require_admin)):
"""Create a new invite token. Returns the plaintext token (shown once)."""
from token_store import create_token as _create
config = get_config()
token_file = config.web_token_file if config else "invite_tokens.json"
plaintext = _create(body.label, token_file, admin=body.admin)
return {"token": plaintext, "label": body.label, "admin": body.admin}
@router.delete("/tokens/{token_id}")
async def revoke_token(token_id: str, _: dict = Depends(require_admin)):
"""Revoke an invite token by ID."""
from token_store import revoke_token as _revoke
config = get_config()
token_file = config.web_token_file if config else "invite_tokens.json"
ok = _revoke(token_id, token_file)
if not ok:
raise HTTPException(status_code=404, detail="Token not found")
return {"ok": True}

View File

@@ -0,0 +1,64 @@
"""POST /api/auth/login|logout; GET /api/auth/me"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from pydantic import BaseModel
from web.auth import create_jwt, require_auth
from web.deps import get_config
from web.login_guard import get_guard, get_real_ip
router = APIRouter()
_COOKIE = "ttb_session"
audit = logging.getLogger("audit")
class LoginRequest(BaseModel):
token: str
@router.post("/login")
async def login(body: LoginRequest, request: Request, response: Response):
"""Exchange an invite token for a JWT session cookie."""
from token_store import verify_token
config = get_config()
token_file = config.web_token_file if config else "invite_tokens.json"
expire_hours = config.web_jwt_expire_hours if config else 8
ip = get_real_ip(request)
get_guard().check(ip)
record = verify_token(body.token, token_file)
if record is None:
get_guard().record_failure(ip)
audit.info("auth.login ip=%s success=False", ip)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
label: str = record["label"]
admin: bool = record.get("admin", False)
get_guard().record_success(ip)
audit.info("auth.login ip=%s success=True label=%s", ip, label)
jwt_token = create_jwt(label, admin=admin, expire_hours=expire_hours)
response.set_cookie(
_COOKIE, jwt_token,
httponly=True, secure=True, samesite="strict",
max_age=expire_hours * 3600,
)
return {"label": label, "admin": admin}
@router.post("/logout")
async def logout(response: Response):
"""Clear the session cookie."""
response.delete_cookie(_COOKIE)
return {"ok": True}
@router.get("/me")
async def me(user: dict = Depends(require_auth)):
"""Return current user info."""
return {"label": user["sub"], "admin": user.get("admin", False)}

View File

@@ -0,0 +1,255 @@
"""POST /api/generate and /api/workflow-gen"""
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Any, Dict, Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from web.auth import require_auth
from web.deps import get_comfy, get_config, get_user_registry
from web.ws_bus import get_bus
router = APIRouter()
logger = logging.getLogger(__name__)
class GenerateRequest(BaseModel):
prompt: str
negative_prompt: Optional[str] = None
overrides: Optional[Dict[str, Any]] = None # extra per-request overrides
class WorkflowGenRequest(BaseModel):
count: int = 1
overrides: Optional[Dict[str, Any]] = None # per-request overrides (merged with state)
@router.post("/generate")
async def generate(body: GenerateRequest, user: dict = Depends(require_auth)):
"""Submit a prompt-based generation to ComfyUI."""
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "ComfyUI client not available")
user_label: str = user["sub"]
bus = get_bus()
registry = get_user_registry()
# Temporary seed override from request
if body.overrides and "seed" in body.overrides:
seed_override = body.overrides["seed"]
elif registry:
seed_override = registry.get_state_manager(user_label).get_seed()
else:
seed_override = comfy.state_manager.get_seed()
overrides_for_gen = {"prompt": body.prompt}
if body.negative_prompt:
overrides_for_gen["negative_prompt"] = body.negative_prompt
if seed_override is not None:
overrides_for_gen["seed"] = seed_override
# Also apply any extra per-request overrides
if body.overrides:
overrides_for_gen.update(body.overrides)
# Get queue position estimate
depth = await comfy.get_queue_depth()
# Start generation as background task so we can return the prompt_id immediately
prompt_id_holder: list = []
async def _run():
# Use the user's own workflow template
if registry:
template = registry.get_workflow_template(user_label)
else:
template = comfy.workflow_manager.get_workflow_template()
if not template:
await bus.broadcast_to_user(user_label, "generation_error", {
"prompt_id": None, "error": "No workflow template loaded"
})
return
import uuid
pid = str(uuid.uuid4())
prompt_id_holder.append(pid)
def on_progress(node, pid_):
asyncio.create_task(bus.broadcast("node_executing", {
"node": node, "prompt_id": pid_
}))
workflow, applied = comfy.inspector.inject_overrides(template, overrides_for_gen)
seed_used = applied.get("seed")
comfy.last_seed = seed_used
try:
images, videos = await comfy._general_generate(workflow, pid, on_progress)
except Exception as exc:
logger.exception("Generation error for prompt %s", pid)
await bus.broadcast_to_user(user_label, "generation_error", {
"prompt_id": pid, "error": str(exc)
})
return
comfy.last_prompt_id = pid
comfy.total_generated += 1
# Persist to DB before flush_pending deletes local files
config = get_config()
try:
from generation_db import record_generation, record_file
gen_id = record_generation(pid, "web", user_label, overrides_for_gen, seed_used)
for i, img_data in enumerate(images):
record_file(gen_id, f"image_{i:04d}.png", img_data)
if config and videos:
for vid in videos:
vsub = vid.get("video_subfolder", "")
vname = vid.get("video_name", "")
vpath = (
Path(config.comfy_output_path) / vsub / vname
if vsub
else Path(config.comfy_output_path) / vname
)
try:
record_file(gen_id, vname, vpath.read_bytes())
except OSError:
pass
except Exception as exc:
logger.warning("Failed to record generation to DB: %s", exc)
# Flush auto-upload
if config:
from media_uploader import flush_pending
asyncio.create_task(flush_pending(
Path(config.comfy_output_path),
config.media_upload_user,
config.media_upload_pass,
))
await bus.broadcast("queue_update", {
"prompt_id": pid,
"status": "complete",
})
await bus.broadcast_to_user(user_label, "generation_complete", {
"prompt_id": pid,
"seed": seed_used,
"image_count": len(images),
"video_count": len(videos),
})
asyncio.create_task(_run())
return {
"queued": True,
"queue_position": depth + 1,
"message": "Generation submitted to ComfyUI",
}
@router.post("/workflow-gen")
async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_auth)):
"""Submit workflow-based generation(s) to ComfyUI."""
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "ComfyUI client not available")
user_label: str = user["sub"]
bus = get_bus()
registry = get_user_registry()
count = max(1, min(body.count, 20)) # cap at 20
async def _run_one():
# Use the user's own state and template
if registry:
user_sm = registry.get_state_manager(user_label)
user_template = registry.get_workflow_template(user_label)
else:
user_sm = comfy.state_manager
user_template = comfy.workflow_manager.get_workflow_template()
if not user_template:
await bus.broadcast_to_user(user_label, "generation_error", {
"prompt_id": None, "error": "No workflow template loaded"
})
return
overrides = user_sm.get_overrides()
if body.overrides:
overrides = {**overrides, **body.overrides}
import uuid
pid = str(uuid.uuid4())
def on_progress(node, pid_):
asyncio.create_task(bus.broadcast("node_executing", {
"node": node, "prompt_id": pid_
}))
workflow, applied = comfy.inspector.inject_overrides(user_template, overrides)
seed_used = applied.get("seed")
comfy.last_seed = seed_used
try:
images, videos = await comfy._general_generate(workflow, pid, on_progress)
except Exception as exc:
logger.exception("Workflow gen error")
await bus.broadcast_to_user(user_label, "generation_error", {
"prompt_id": None, "error": str(exc)
})
return
comfy.last_prompt_id = pid
comfy.total_generated += 1
config = get_config()
try:
from generation_db import record_generation, record_file
gen_id = record_generation(pid, "web", user_label, overrides, seed_used)
for i, img_data in enumerate(images):
record_file(gen_id, f"image_{i:04d}.png", img_data)
if config and videos:
for vid in videos:
vsub = vid.get("video_subfolder", "")
vname = vid.get("video_name", "")
vpath = (
Path(config.comfy_output_path) / vsub / vname
if vsub
else Path(config.comfy_output_path) / vname
)
try:
record_file(gen_id, vname, vpath.read_bytes())
except OSError:
pass
except Exception as exc:
logger.warning("Failed to record generation to DB: %s", exc)
if config:
from media_uploader import flush_pending
asyncio.create_task(flush_pending(
Path(config.comfy_output_path),
config.media_upload_user,
config.media_upload_pass,
))
await bus.broadcast("queue_update", {"prompt_id": pid, "status": "complete"})
await bus.broadcast_to_user(user_label, "generation_complete", {
"prompt_id": pid,
"seed": seed_used,
"image_count": len(images),
"video_count": len(videos),
})
depth = await comfy.get_queue_depth()
for _ in range(count):
asyncio.create_task(_run_one())
return {
"queued": True,
"count": count,
"queue_position": depth + 1,
}

View File

@@ -0,0 +1,143 @@
"""GET /api/history; GET /api/history/{prompt_id}/images; GET /api/history/{prompt_id}/file/{filename};
POST /api/history/{prompt_id}/share; DELETE /api/history/{prompt_id}/share"""
from __future__ import annotations
import base64
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import Response
from web.auth import require_auth
router = APIRouter()
def _assert_owns(prompt_id: str, user: dict) -> None:
"""Raise 404 if the generation doesn't exist or doesn't belong to the user.
Returning the same 404 for both cases prevents leaking whether a
prompt_id exists to users who don't own it. Admins bypass this check.
"""
if user.get("admin"):
return
from generation_db import get_generation
gen = get_generation(prompt_id)
if gen is None or gen["user_label"] != user["sub"]:
raise HTTPException(404, "Not found")
@router.get("")
async def get_history(
user: dict = Depends(require_auth),
q: Optional[str] = Query(None, description="Keyword to search in overrides JSON"),
):
"""Return generation history. Admins see all; regular users see only their own.
Pass ?q=keyword to filter by prompt text or any override field."""
from generation_db import (
get_history as db_get_history,
get_history_for_user,
search_history,
search_history_for_user,
)
if q and q.strip():
if user.get("admin"):
return {"history": search_history(q.strip(), limit=50)}
return {"history": search_history_for_user(user["sub"], q.strip(), limit=50)}
if user.get("admin"):
return {"history": db_get_history(limit=50)}
return {"history": get_history_for_user(user["sub"], limit=50)}
@router.get("/{prompt_id}/images")
async def get_history_images(prompt_id: str, user: dict = Depends(require_auth)):
"""
Fetch output files for a past generation.
Returns base64-encoded blobs from the local SQLite DB — works even after
``flush_pending`` has deleted the files from disk.
"""
_assert_owns(prompt_id, user)
from generation_db import get_files
files = get_files(prompt_id)
if not files:
raise HTTPException(404, f"No files found for prompt_id {prompt_id!r}")
return {
"prompt_id": prompt_id,
"images": [
{
"filename": f["filename"],
"data": base64.b64encode(f["data"]).decode() if not f["mime_type"].startswith("video/") else None,
"mime_type": f["mime_type"],
}
for f in files
],
}
@router.get("/{prompt_id}/file/{filename}")
async def get_history_file(
prompt_id: str,
filename: str,
request: Request,
user: dict = Depends(require_auth),
):
"""Stream a single output file, with HTTP range request support for video seeking."""
_assert_owns(prompt_id, user)
from generation_db import get_files
files = get_files(prompt_id)
matched = next((f for f in files if f["filename"] == filename), None)
if matched is None:
raise HTTPException(404, f"File {filename!r} not found for prompt_id {prompt_id!r}")
data: bytes = matched["data"]
mime: str = matched["mime_type"]
total = len(data)
range_header = request.headers.get("range")
if range_header:
range_val = range_header.replace("bytes=", "")
start_str, _, end_str = range_val.partition("-")
start = int(start_str) if start_str else 0
end = int(end_str) if end_str else total - 1
end = min(end, total - 1)
chunk = data[start : end + 1]
return Response(
content=chunk,
status_code=206,
media_type=mime,
headers={
"Content-Range": f"bytes {start}-{end}/{total}",
"Accept-Ranges": "bytes",
"Content-Length": str(len(chunk)),
},
)
return Response(
content=data,
media_type=mime,
headers={
"Accept-Ranges": "bytes",
"Content-Length": str(total),
},
)
@router.post("/{prompt_id}/share")
async def create_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
"""Create a share token for a generation. Only the owner may share."""
# Use the same 404-for-everything helper to avoid leaking prompt_id existence
_assert_owns(prompt_id, user)
from generation_db import create_share as db_create_share
token = db_create_share(prompt_id, user["sub"])
return {"share_token": token}
@router.delete("/{prompt_id}/share")
async def revoke_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
"""Revoke a share token for a generation. Only the owner may revoke."""
from generation_db import revoke_share as db_revoke_share
deleted = db_revoke_share(prompt_id, user["sub"])
if not deleted:
raise HTTPException(404, "No active share found for this generation")
return {"ok": True}

View File

@@ -0,0 +1,194 @@
"""GET/POST/DELETE /api/inputs; GET /api/inputs/{id}/image; POST /api/inputs/{id}/activate"""
from __future__ import annotations
import logging
import mimetypes
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, UploadFile
from fastapi.responses import Response
from web.auth import require_auth
from web.deps import get_config, get_user_registry
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("")
async def list_inputs(_: dict = Depends(require_auth)):
"""List all input images (Discord + web uploads)."""
from input_image_db import get_all_images
rows = get_all_images()
return [dict(r) for r in rows]
@router.post("")
async def upload_input(
file: UploadFile = File(...),
slot_key: Optional[str] = Form(default=None),
user: dict = Depends(require_auth),
):
"""
Upload an input image.
Stores image bytes directly in SQLite. If *slot_key* is provided the
image is immediately activated for that slot (writes to ComfyUI input
folder and updates the user's state override).
The physical slot file uses a namespaced key ``<user_label>_<slot_key>``
so concurrent users each get their own active image file.
"""
config = get_config()
if config is None:
raise HTTPException(503, "Config not available")
data = await file.read()
filename = file.filename or "upload.png"
from input_image_db import upsert_image, activate_image_for_slot
row_id = upsert_image(
original_message_id=0, # sentinel for web uploads
bot_reply_id=0,
channel_id=0,
filename=filename,
image_data=data,
)
activated_filename: str | None = None
if slot_key:
user_label: str = user["sub"]
namespaced_key = f"{user_label}_{slot_key}"
activated_filename = activate_image_for_slot(
row_id, namespaced_key, config.comfy_input_path
)
registry = get_user_registry()
if registry:
registry.get_state_manager(user_label).set_override(slot_key, activated_filename)
else:
from web.deps import get_comfy
comfy = get_comfy()
if comfy:
comfy.state_manager.set_override(slot_key, activated_filename)
return {"id": row_id, "filename": filename, "slot_key": slot_key, "activated_filename": activated_filename}
@router.post("/{row_id}/activate")
async def activate_input(
row_id: int,
slot_key: str = Body(default="input_image", embed=True),
user: dict = Depends(require_auth),
):
"""Write the stored image to the ComfyUI input folder and set the user's slot override."""
config = get_config()
if config is None:
raise HTTPException(503, "Config not available")
from input_image_db import get_image, activate_image_for_slot
row = get_image(row_id)
if row is None:
raise HTTPException(404, "Image not found")
user_label: str = user["sub"]
namespaced_key = f"{user_label}_{slot_key}"
try:
filename = activate_image_for_slot(row_id, namespaced_key, config.comfy_input_path)
except ValueError as exc:
raise HTTPException(409, str(exc))
registry = get_user_registry()
if registry:
registry.get_state_manager(user_label).set_override(slot_key, filename)
else:
from web.deps import get_comfy
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "State manager not available")
comfy.state_manager.set_override(slot_key, filename)
return {"ok": True, "slot_key": slot_key, "filename": filename}
@router.delete("/{row_id}")
async def delete_input(row_id: int, _: dict = Depends(require_auth)):
"""Delete an input image record (and its active slot file if applicable)."""
from input_image_db import get_image, delete_image
row = get_image(row_id)
if row is None:
raise HTTPException(404, "Image not found")
config = get_config()
delete_image(row_id, comfy_input_path=config.comfy_input_path if config else None)
return {"ok": True}
@router.get("/{row_id}/image")
async def get_input_image(row_id: int, _: dict = Depends(require_auth)):
"""Serve the raw image bytes stored in the database for a given input image row."""
from input_image_db import get_image, get_image_data
row = get_image(row_id)
if row is None:
raise HTTPException(404, "Image not found")
data = get_image_data(row_id)
if data is None:
raise HTTPException(404, "Image data not available — re-upload to backfill")
mime, _ = mimetypes.guess_type(row["filename"])
return Response(content=data, media_type=mime or "application/octet-stream")
def _pil_resize_response(data: bytes, filename: str, max_size: int, quality: int) -> Response:
"""Resize image bytes with Pillow and return a JPEG Response. Raises on failure."""
import io
from PIL import Image as _PIL
img = _PIL.open(io.BytesIO(data))
img.thumbnail((max_size, max_size), _PIL.LANCZOS)
buf = io.BytesIO()
img.convert("RGB").save(buf, "JPEG", quality=quality, optimize=True)
return Response(
content=buf.getvalue(),
media_type="image/jpeg",
headers={"Cache-Control": "public, max-age=86400"},
)
@router.get("/{row_id}/thumb")
async def get_input_thumb(row_id: int, _: dict = Depends(require_auth)):
"""Serve a small compressed thumbnail (max 200 px, JPEG 65 %) for fast previews."""
from input_image_db import get_image, get_image_data
row = get_image(row_id)
if row is None:
raise HTTPException(404, "Image not found")
data = get_image_data(row_id)
if data is None:
raise HTTPException(404, "Image data not available — re-upload to backfill")
try:
return _pil_resize_response(data, row["filename"], max_size=200, quality=65)
except Exception:
mime, _ = mimetypes.guess_type(row["filename"])
return Response(content=data, media_type=mime or "application/octet-stream")
@router.get("/{row_id}/mid")
async def get_input_mid(row_id: int, _: dict = Depends(require_auth)):
"""Serve a medium compressed image (max 800 px, JPEG 80 %) for progressive loading."""
from input_image_db import get_image, get_image_data
row = get_image(row_id)
if row is None:
raise HTTPException(404, "Image not found")
data = get_image_data(row_id)
if data is None:
raise HTTPException(404, "Image data not available — re-upload to backfill")
try:
return _pil_resize_response(data, row["filename"], max_size=800, quality=80)
except Exception:
mime, _ = mimetypes.guess_type(row["filename"])
return Response(content=data, media_type=mime or "application/octet-stream")

View File

@@ -0,0 +1,153 @@
"""CRUD for workflow presets via /api/presets"""
from __future__ import annotations
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from web.auth import require_auth
from web.deps import get_comfy, get_user_registry
router = APIRouter()
class SavePresetRequest(BaseModel):
name: str
description: Optional[str] = None
class SaveFromHistoryRequest(BaseModel):
name: str
description: Optional[str] = None
def _get_pm():
from web.deps import get_bot
bot = get_bot()
pm = getattr(bot, "preset_manager", None) if bot else None
if pm is None:
from preset_manager import PresetManager
pm = PresetManager()
return pm
@router.get("")
async def list_presets(_: dict = Depends(require_auth)):
pm = _get_pm()
return {"presets": pm.list_preset_details()}
@router.post("")
async def save_preset(body: SavePresetRequest, user: dict = Depends(require_auth)):
"""Capture the user's overrides + workflow template as a named preset."""
user_label: str = user["sub"]
registry = get_user_registry()
pm = _get_pm()
if registry:
workflow_template = registry.get_workflow_template(user_label)
overrides = registry.get_state_manager(user_label).get_overrides()
else:
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "ComfyUI not available")
workflow_template = comfy.get_workflow_template()
overrides = comfy.state_manager.get_overrides()
try:
pm.save(body.name, workflow_template, overrides, owner=user_label, description=body.description)
except ValueError as exc:
raise HTTPException(400, str(exc))
return {"ok": True, "name": body.name}
@router.get("/{name}")
async def get_preset(name: str, _: dict = Depends(require_auth)):
pm = _get_pm()
data = pm.load(name)
if data is None:
raise HTTPException(404, "Preset not found")
return data
@router.post("/{name}/load")
async def load_preset(name: str, user: dict = Depends(require_auth)):
"""Restore overrides (and optionally workflow template) from a preset into the user's state."""
pm = _get_pm()
data = pm.load(name)
if data is None:
raise HTTPException(404, "Preset not found")
user_label: str = user["sub"]
registry = get_user_registry()
if registry:
wf = data.get("workflow")
if wf:
registry.set_workflow(user_label, wf, name)
else:
# No workflow in preset — just clear overrides and restore state
registry.get_state_manager(user_label).clear_overrides()
state = data.get("state", {})
sm = registry.get_state_manager(user_label)
for k, v in state.items():
if v is not None:
sm.set_override(k, v)
else:
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "ComfyUI not available")
comfy.state_manager.clear_overrides()
state = data.get("state", {})
for k, v in state.items():
if v is not None:
comfy.state_manager.set_override(k, v)
wf = data.get("workflow")
if wf:
comfy.workflow_manager.set_workflow_template(wf)
return {"ok": True, "name": name, "overrides_restored": list(data.get("state", {}).keys())}
@router.delete("/{name}")
async def delete_preset(name: str, user: dict = Depends(require_auth)):
pm = _get_pm()
data = pm.load(name)
if data is None:
raise HTTPException(404, "Preset not found")
user_label: str = user["sub"]
is_admin = user.get("admin") is True
owner = data.get("owner")
if owner is not None and owner != user_label and not is_admin:
raise HTTPException(403, "You do not have permission to delete this preset")
pm.delete(name)
return {"ok": True}
@router.post("/from-history/{prompt_id}")
async def save_preset_from_history(
prompt_id: str,
body: SaveFromHistoryRequest,
user: dict = Depends(require_auth),
):
"""Create a preset from a past generation's overrides."""
from generation_db import get_generation_full
gen = get_generation_full(prompt_id)
if gen is None:
raise HTTPException(404, "Generation not found")
user_label: str = user["sub"]
is_admin = user.get("admin") is True
if not is_admin and gen.get("user_label") != user_label:
raise HTTPException(404, "Generation not found")
pm = _get_pm()
try:
pm.save(body.name, None, gen["overrides"], owner=user_label, description=body.description)
except ValueError as exc:
raise HTTPException(400, str(exc))
return {"ok": True, "name": body.name}

View File

@@ -0,0 +1,90 @@
"""GET/POST /api/server/{action}; GET /api/logs/tail"""
from __future__ import annotations
import logging
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException
from web.auth import require_auth
from web.deps import get_config, get_comfy
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/server/status")
async def server_status(_: dict = Depends(require_auth)):
"""Return NSSM service state and HTTP health."""
config = get_config()
if config is None:
raise HTTPException(503, "Config not available")
from commands.server import get_service_state
import asyncio
async def _false():
return False
comfy = get_comfy()
service_state, http_ok = await asyncio.gather(
get_service_state(config.comfy_service_name),
comfy.check_connection() if comfy else _false(),
)
return {"service_state": service_state, "http_reachable": http_ok}
@router.post("/server/{action}")
async def server_action(action: str, _: dict = Depends(require_auth)):
"""Control the ComfyUI service: start | stop | restart | install | uninstall"""
config = get_config()
if config is None:
raise HTTPException(503, "Config not available")
valid_actions = {"start", "stop", "restart", "install", "uninstall"}
if action not in valid_actions:
raise HTTPException(400, f"Invalid action '{action}'")
from commands.server import _nssm, _install_service
import asyncio
try:
if action == "install":
ok, msg = await _install_service(config)
if not ok:
raise HTTPException(500, msg)
elif action == "uninstall":
await _nssm("stop", config.comfy_service_name)
await _nssm("remove", config.comfy_service_name, "confirm")
elif action == "start":
await _nssm("start", config.comfy_service_name)
elif action == "stop":
await _nssm("stop", config.comfy_service_name)
elif action == "restart":
await _nssm("restart", config.comfy_service_name)
except HTTPException:
raise
except Exception as exc:
raise HTTPException(500, str(exc))
return {"ok": True, "action": action}
@router.get("/logs/tail")
async def tail_logs(lines: int = 100, _: dict = Depends(require_auth)):
"""Tail the ComfyUI log file."""
config = get_config()
if config is None or not config.comfy_log_dir:
raise HTTPException(503, "Log directory not configured")
log_dir = Path(config.comfy_log_dir)
log_file = log_dir / "comfyui.log"
if not log_file.exists():
return {"lines": []}
try:
with open(log_file, "r", encoding="utf-8", errors="replace") as f:
all_lines = f.readlines()
tail = all_lines[-min(lines, len(all_lines)):]
return {"lines": [ln.rstrip("\n") for ln in tail]}
except Exception as exc:
raise HTTPException(500, str(exc))

View File

@@ -0,0 +1,85 @@
"""GET /api/share/{token}; GET /api/share/{token}/file/{filename}"""
from __future__ import annotations
import base64
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
from web.auth import require_auth
router = APIRouter()
@router.get("/{token}")
async def get_share(token: str, _: dict = Depends(require_auth)):
"""Fetch share metadata and images. Any authenticated user may view a valid share link."""
from generation_db import get_share_by_token, get_files
gen = get_share_by_token(token)
if gen is None:
raise HTTPException(404, "Share not found or revoked")
files = get_files(gen["prompt_id"])
return {
"prompt_id": gen["prompt_id"],
"created_at": gen["created_at"],
"overrides": gen["overrides"],
"seed": gen["seed"],
"images": [
{
"filename": f["filename"],
"data": base64.b64encode(f["data"]).decode() if not f["mime_type"].startswith("video/") else None,
"mime_type": f["mime_type"],
}
for f in files
],
}
@router.get("/{token}/file/{filename}")
async def get_share_file(
token: str,
filename: str,
request: Request,
_: dict = Depends(require_auth),
):
"""Stream a single output file via share token, with HTTP range support for video seeking."""
from generation_db import get_share_by_token, get_files
gen = get_share_by_token(token)
if gen is None:
raise HTTPException(404, "Share not found or revoked")
files = get_files(gen["prompt_id"])
matched = next((f for f in files if f["filename"] == filename), None)
if matched is None:
raise HTTPException(404, f"File {filename!r} not found")
data: bytes = matched["data"]
mime: str = matched["mime_type"]
total = len(data)
range_header = request.headers.get("range")
if range_header:
range_val = range_header.replace("bytes=", "")
start_str, _, end_str = range_val.partition("-")
start = int(start_str) if start_str else 0
end = int(end_str) if end_str else total - 1
end = min(end, total - 1)
chunk = data[start : end + 1]
return Response(
content=chunk,
status_code=206,
media_type=mime,
headers={
"Content-Range": f"bytes {start}-{end}/{total}",
"Accept-Ranges": "bytes",
"Content-Length": str(len(chunk)),
},
)
return Response(
content=data,
media_type=mime,
headers={
"Accept-Ranges": "bytes",
"Content-Length": str(total),
},
)

View File

@@ -0,0 +1,53 @@
"""GET/PUT /api/state; DELETE /api/state/{key}"""
from __future__ import annotations
from typing import Any, Dict
from fastapi import APIRouter, Depends, HTTPException
from web.auth import require_auth
from web.deps import get_config, get_user_registry
router = APIRouter()
def _get_user_sm(user: dict):
"""Return the per-user WorkflowStateManager, raising 503 if unavailable."""
registry = get_user_registry()
if registry is None:
raise HTTPException(503, "State manager not available")
return registry.get_state_manager(user["sub"])
@router.get("/state")
async def get_state(user: dict = Depends(require_auth)):
"""Return all current overrides for the authenticated user."""
sm = _get_user_sm(user)
return sm.get_overrides()
@router.put("/state")
async def put_state(body: Dict[str, Any], user: dict = Depends(require_auth)):
"""Merge override values. Pass ``null`` as a value to delete a key."""
sm = _get_user_sm(user)
for key, value in body.items():
if value is None:
sm.delete_override(key)
else:
sm.set_override(key, value)
return sm.get_overrides()
@router.delete("/state/{key}")
async def delete_state_key(key: str, user: dict = Depends(require_auth)):
"""Remove a single override key, and clean up any associated slot file."""
sm = _get_user_sm(user)
sm.delete_override(key)
config = get_config()
if config:
from input_image_db import deactivate_image_slot
user_label: str = user["sub"]
deactivate_image_slot(f"{user_label}_{key}", config.comfy_input_path)
return {"ok": True, "key": key}

View File

@@ -0,0 +1,74 @@
"""GET /api/status — polling fallback for clients that can't use WebSocket"""
from __future__ import annotations
import asyncio
from fastapi import APIRouter, Depends
from web.auth import require_auth
from web.deps import get_bot, get_comfy, get_config
router = APIRouter()
@router.get("/status")
async def get_status(_: dict = Depends(require_auth)):
"""Return a full status snapshot."""
bot = get_bot()
comfy = get_comfy()
config = get_config()
snap: dict = {}
if bot is not None:
import datetime as _dt
lat = bot.latency
lat_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0
start = getattr(bot, "start_time", None)
uptime = ""
if start:
delta = _dt.datetime.now(_dt.timezone.utc) - start
total = int(delta.total_seconds())
h, rem = divmod(total, 3600)
m, s = divmod(rem, 60)
uptime = f"{h}h {m}m {s}s" if h else (f"{m}m {s}s" if m else f"{s}s")
snap["bot"] = {"latency_ms": lat_ms, "uptime": uptime}
if comfy is not None:
q_task = asyncio.create_task(comfy.get_comfy_queue())
conn_task = asyncio.create_task(comfy.check_connection())
q, reachable = await asyncio.gather(q_task, conn_task)
pending = len(q.get("queue_pending", [])) if q else 0
running = len(q.get("queue_running", [])) if q else 0
wm = getattr(comfy, "workflow_manager", None)
wf_loaded = wm is not None and wm.get_workflow_template() is not None
snap["comfy"] = {
"server": comfy.server_address,
"reachable": reachable,
"queue_pending": pending,
"queue_running": running,
"workflow_loaded": wf_loaded,
"last_seed": comfy.last_seed,
"total_generated": comfy.total_generated,
}
snap["overrides"] = comfy.state_manager.get_overrides()
if config is not None:
from commands.server import get_service_state
service_state = await get_service_state(config.comfy_service_name)
snap["service"] = {"state": service_state}
try:
from media_uploader import get_stats as us_fn, is_running as ur_fn
us = us_fn()
snap["upload"] = {
"configured": bool(config.media_upload_user),
"running": ur_fn(),
"total_ok": us.total_ok,
"total_fail": us.total_fail,
}
except Exception:
pass
return snap

View File

@@ -0,0 +1,175 @@
"""
GET /api/workflow — current workflow info
GET /api/workflow/inputs — dynamic NodeInput list
GET /api/workflow/files — list files in workflows/
POST /api/workflow/upload — upload a workflow JSON
POST /api/workflow/load — load a workflow from workflows/
GET /api/workflow/models?type=checkpoints|loras — available models (60s TTL cache)
"""
from __future__ import annotations
import json
import logging
import time
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from web.auth import require_auth
from web.deps import get_comfy, get_config, get_inspector, get_user_registry
router = APIRouter()
logger = logging.getLogger(__name__)
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
_WORKFLOWS_DIR = _PROJECT_ROOT / "workflows"
# Simple in-memory TTL cache for models
_models_cache: dict = {}
_MODELS_TTL = 60.0
@router.get("")
async def get_workflow(user: dict = Depends(require_auth)):
"""Return basic info about the currently loaded workflow."""
user_label: str = user["sub"]
registry = get_user_registry()
if registry:
template = registry.get_workflow_template(user_label)
last_wf = registry.get_state_manager(user_label).get_last_workflow_file()
else:
# Fallback to global state when registry is unavailable
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "Workflow manager not available")
template = comfy.workflow_manager.get_workflow_template()
last_wf = comfy.state_manager.get_last_workflow_file()
return {
"loaded": template is not None,
"node_count": len(template) if template else 0,
"last_workflow_file": last_wf,
}
@router.get("/inputs")
async def get_workflow_inputs(user: dict = Depends(require_auth)):
"""Return dynamic NodeInput list (common + advanced) for the current workflow."""
user_label: str = user["sub"]
inspector = get_inspector()
if inspector is None:
raise HTTPException(503, "Workflow components not available")
registry = get_user_registry()
if registry:
template = registry.get_workflow_template(user_label)
overrides = registry.get_state_manager(user_label).get_overrides()
else:
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "Workflow components not available")
template = comfy.workflow_manager.get_workflow_template()
overrides = comfy.state_manager.get_overrides()
if template is None:
return {"common": [], "advanced": []}
inputs = inspector.inspect(template)
result = []
for ni in inputs:
val = overrides.get(ni.key, ni.current_value)
result.append({
"key": ni.key,
"label": ni.label,
"input_type": ni.input_type,
"current_value": val,
"node_class": ni.node_class,
"node_title": ni.node_title,
"is_common": ni.is_common,
})
common = [r for r in result if r["is_common"]]
advanced = [r for r in result if not r["is_common"]]
return {"common": common, "advanced": advanced}
@router.get("/files")
async def list_workflow_files(_: dict = Depends(require_auth)):
"""List .json files in the workflows/ folder."""
_WORKFLOWS_DIR.mkdir(parents=True, exist_ok=True)
files = sorted(p.name for p in _WORKFLOWS_DIR.glob("*.json"))
return {"files": files}
@router.post("/upload")
async def upload_workflow(
file: UploadFile = File(...),
_: dict = Depends(require_auth),
):
"""Upload a workflow JSON to the workflows/ folder."""
_WORKFLOWS_DIR.mkdir(parents=True, exist_ok=True)
filename = file.filename or "workflow.json"
if not filename.endswith(".json"):
filename += ".json"
data = await file.read()
try:
json.loads(data) # validate JSON
except json.JSONDecodeError as exc:
raise HTTPException(400, f"Invalid JSON: {exc}")
dest = _WORKFLOWS_DIR / filename
dest.write_bytes(data)
return {"ok": True, "filename": filename}
@router.post("/load")
async def load_workflow(filename: str = Form(...), user: dict = Depends(require_auth)):
"""Load a workflow from the workflows/ folder into the user's isolated state."""
wf_path = _WORKFLOWS_DIR / filename
if not wf_path.exists():
raise HTTPException(404, f"Workflow file '{filename}' not found")
try:
with open(wf_path, "r", encoding="utf-8") as f:
workflow = json.load(f)
except Exception as exc:
raise HTTPException(500, str(exc))
registry = get_user_registry()
if registry:
registry.set_workflow(user["sub"], workflow, filename)
else:
# Fallback: update global state when registry unavailable
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "ComfyUI not available")
comfy.workflow_manager.set_workflow_template(workflow)
comfy.state_manager.clear_overrides()
comfy.state_manager.set_last_workflow_file(filename)
inspector = get_inspector()
node_count = len(workflow)
inputs_count = len(inspector.inspect(workflow)) if inspector else 0
return {
"ok": True,
"filename": filename,
"node_count": node_count,
"inputs_count": inputs_count,
}
@router.get("/models")
async def get_models(type: str = "checkpoints", _: dict = Depends(require_auth)):
"""Return available model names from ComfyUI (60s TTL cache)."""
global _models_cache
now = time.time()
cache_key = type
cached = _models_cache.get(cache_key)
if cached and (now - cached["ts"]) < _MODELS_TTL:
return {"type": type, "models": cached["models"]}
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "ComfyUI not available")
models = await comfy.get_models(type)
_models_cache[cache_key] = {"models": models, "ts": now}
return {"type": type, "models": models}

61
web/routers/ws_router.py Normal file
View File

@@ -0,0 +1,61 @@
"""WebSocket /ws?token=<jwt> — real-time event stream"""
from __future__ import annotations
import asyncio
import logging
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from web.auth import verify_ws_token
from web.ws_bus import get_bus
router = APIRouter()
logger = logging.getLogger(__name__)
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: str = ""):
"""
Authenticate via JWT query param or ttb_session cookie, then stream events from WSBus.
Events common to all users: status_snapshot, queue_update, node_executing, server_state
Events private to submitter: generation_complete, generation_error
"""
payload = verify_ws_token(token)
if payload is None:
# Fallback: browsers send cookies automatically with WebSocket connections
cookie_token = websocket.cookies.get("ttb_session", "")
payload = verify_ws_token(cookie_token)
if payload is None:
await websocket.close(code=4001, reason="Unauthorized")
return
user_label: str = payload.get("sub", "anonymous")
bus = get_bus()
queue = bus.subscribe(user_label)
await websocket.accept()
logger.info("WS connected: user=%s", user_label)
try:
while True:
# Wait for an event from the bus
try:
frame = await asyncio.wait_for(queue.get(), timeout=30.0)
except asyncio.TimeoutError:
# Send a keepalive ping
try:
await websocket.send_text('{"type":"ping"}')
except Exception:
break
continue
try:
await websocket.send_text(frame)
except Exception:
break
except WebSocketDisconnect:
pass
finally:
bus.unsubscribe(user_label, queue)
logger.info("WS disconnected: user=%s", user_label)

122
web/ws_bus.py Normal file
View File

@@ -0,0 +1,122 @@
"""
web/ws_bus.py
=============
In-process WebSocket event bus.
All connected web clients share a single WSBus instance. Events are
delivered per-user (private results) or to all users (shared status).
Usage::
bus = WSBus()
# Subscribe (returns a queue; caller reads from it)
q = bus.subscribe("alice")
# Broadcast to all
await bus.broadcast("status_snapshot", {...})
# Broadcast to one user (all their open tabs)
await bus.broadcast_to_user("alice", "generation_complete", {...})
# Unsubscribe when WS disconnects
bus.unsubscribe("alice", q)
Event frame format sent on wire:
{"type": "event_name", "data": {...}, "ts": 1234567890.123}
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from typing import Any, Dict, Set
logger = logging.getLogger(__name__)
class WSBus:
"""
Per-user broadcast bus backed by asyncio queues.
Thread-safe as long as all callers run in the same event loop.
"""
def __init__(self) -> None:
# user_label → set of asyncio.Queue
self._clients: Dict[str, Set[asyncio.Queue]] = {}
# ------------------------------------------------------------------
# Subscription lifecycle
# ------------------------------------------------------------------
def subscribe(self, user_label: str) -> asyncio.Queue:
"""Register a new client connection. Returns the queue to read from."""
q: asyncio.Queue = asyncio.Queue(maxsize=256)
self._clients.setdefault(user_label, set()).add(q)
logger.debug("WSBus: %s subscribed (%d queues)", user_label,
len(self._clients[user_label]))
return q
def unsubscribe(self, user_label: str, queue: asyncio.Queue) -> None:
"""Remove a client connection."""
queues = self._clients.get(user_label, set())
queues.discard(queue)
if not queues:
self._clients.pop(user_label, None)
logger.debug("WSBus: %s unsubscribed", user_label)
@property
def connected_users(self) -> list[str]:
"""List of user labels with at least one active connection."""
return list(self._clients.keys())
@property
def total_connections(self) -> int:
return sum(len(qs) for qs in self._clients.values())
# ------------------------------------------------------------------
# Broadcasting
# ------------------------------------------------------------------
def _frame(self, event_type: str, data: Any) -> str:
return json.dumps({"type": event_type, "data": data, "ts": time.time()})
async def broadcast(self, event_type: str, data: Any) -> None:
"""Send an event to ALL connected clients."""
frame = self._frame(event_type, data)
for queues in list(self._clients.values()):
for q in list(queues):
try:
q.put_nowait(frame)
except asyncio.QueueFull:
logger.warning("WSBus: queue full, dropping %s event", event_type)
async def broadcast_to_user(
self, user_label: str, event_type: str, data: Any
) -> None:
"""Send an event to all connections belonging to *user_label*."""
queues = self._clients.get(user_label, set())
if not queues:
logger.debug("WSBus: no clients for user '%s', dropping %s", user_label, event_type)
return
frame = self._frame(event_type, data)
for q in list(queues):
try:
q.put_nowait(frame)
except asyncio.QueueFull:
logger.warning("WSBus: queue full for %s, dropping %s", user_label, event_type)
# Module-level singleton (set by web/app.py)
_bus: WSBus | None = None
def get_bus() -> WSBus:
global _bus
if _bus is None:
_bus = WSBus()
return _bus