Initial commit — ComfyUI Discord bot + web UI
Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI (React/TS/Vite/Tailwind), ComfyUI integration, generation history DB, preset manager, workflow inspector, and all supporting modules. Excluded from tracking: .env, invite_tokens.json, *.db (SQLite), current-workflow-changes.json, user_settings/, presets/, logs/, web-static/ (build output), frontend/node_modules/. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
1
web/__init__.py
Normal file
1
web/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# web package
|
||||
269
web/app.py
Normal file
269
web/app.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
web/app.py
|
||||
==========
|
||||
|
||||
FastAPI application factory.
|
||||
|
||||
The app is created once and shared between the Uvicorn server (started
|
||||
from bot.py via asyncio.gather) and tests.
|
||||
|
||||
Startup tasks:
|
||||
- Background status ticker (broadcasts status_snapshot every 5s to all clients)
|
||||
- Background NSSM poll (broadcasts server_state every 10s to all clients)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
from starlette.exceptions import HTTPException as _HTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request as _Request
|
||||
|
||||
# Windows registry can map .js → text/plain; override to the correct types
|
||||
# before StaticFiles reads them.
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("application/javascript", ".mjs")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
mimetypes.add_type("application/wasm", ".wasm")
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
|
||||
class _NoCacheHTMLMiddleware(BaseHTTPMiddleware):
|
||||
"""Force browsers to revalidate index.html on every request.
|
||||
|
||||
Vite hashes JS/CSS filenames on every build so those assets are
|
||||
naturally cache-busted. index.html itself has a stable name, so
|
||||
without an explicit Cache-Control header mobile browsers apply
|
||||
heuristic caching and keep serving a stale copy after a redeploy.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: _Request, call_next):
|
||||
response = await call_next(request)
|
||||
ct = response.headers.get("content-type", "")
|
||||
if "text/html" in ct:
|
||||
response.headers["Cache-Control"] = "no-cache, must-revalidate"
|
||||
return response
|
||||
|
||||
|
||||
class _SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Add security headers to every response."""
|
||||
|
||||
_CSP = (
|
||||
"default-src 'self'; "
|
||||
"script-src 'self' 'unsafe-inline'; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
"img-src 'self' data: blob:; "
|
||||
"connect-src 'self' wss:; "
|
||||
"frame-ancestors 'none';"
|
||||
)
|
||||
|
||||
async def dispatch(self, request: _Request, call_next):
|
||||
response = await call_next(request)
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Content-Security-Policy"] = self._CSP
|
||||
return response
|
||||
|
||||
class _SPAStaticFiles(StaticFiles):
|
||||
"""StaticFiles with SPA fallback: serve index.html for unknown paths.
|
||||
|
||||
Starlette's html=True only serves index.html for directory requests.
|
||||
This subclass additionally returns index.html for any path that has no
|
||||
matching file, so client-side routes like /generate work on refresh.
|
||||
"""
|
||||
|
||||
async def get_response(self, path: str, scope):
|
||||
try:
|
||||
return await super().get_response(path, scope)
|
||||
except _HTTPException as ex:
|
||||
if ex.status_code == 404:
|
||||
return await super().get_response("index.html", scope)
|
||||
raise
|
||||
|
||||
|
||||
from web.ws_bus import get_bus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
_WEB_STATIC = _PROJECT_ROOT / "web-static"
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create and configure the FastAPI application."""
|
||||
app = FastAPI(
|
||||
title="ComfyUI Bot Web UI",
|
||||
version="1.0.0",
|
||||
docs_url=None,
|
||||
redoc_url=None,
|
||||
openapi_url=None,
|
||||
)
|
||||
|
||||
# CORS — only allow explicitly configured origins; empty = no cross-origin
|
||||
_cors_origins = [o.strip() for o in os.getenv("CORS_ORIGINS", "").split(",") if o.strip()]
|
||||
if _cors_origins:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=_cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Security headers on every response
|
||||
app.add_middleware(_SecurityHeadersMiddleware)
|
||||
|
||||
# Prevent browsers from caching index.html across deploys
|
||||
app.add_middleware(_NoCacheHTMLMiddleware)
|
||||
|
||||
# Register API routers
|
||||
from web.routers.auth_router import router as auth_router
|
||||
from web.routers.admin_router import router as admin_router
|
||||
from web.routers.status_router import router as status_router
|
||||
from web.routers.state_router import router as state_router
|
||||
from web.routers.generate_router import router as generate_router
|
||||
from web.routers.inputs_router import router as inputs_router
|
||||
from web.routers.presets_router import router as presets_router
|
||||
from web.routers.server_router import router as server_router
|
||||
from web.routers.history_router import router as history_router
|
||||
from web.routers.share_router import router as share_router
|
||||
from web.routers.workflow_router import router as workflow_router
|
||||
from web.routers.ws_router import router as ws_router
|
||||
|
||||
app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
|
||||
app.include_router(admin_router, prefix="/api/admin", tags=["admin"])
|
||||
app.include_router(status_router, prefix="/api", tags=["status"])
|
||||
app.include_router(state_router, prefix="/api", tags=["state"])
|
||||
app.include_router(generate_router, prefix="/api", tags=["generate"])
|
||||
app.include_router(inputs_router, prefix="/api/inputs", tags=["inputs"])
|
||||
app.include_router(presets_router, prefix="/api/presets", tags=["presets"])
|
||||
app.include_router(server_router, prefix="/api", tags=["server"])
|
||||
app.include_router(history_router, prefix="/api/history", tags=["history"])
|
||||
app.include_router(share_router, prefix="/api/share", tags=["share"])
|
||||
app.include_router(workflow_router, prefix="/api/workflow", tags=["workflow"])
|
||||
app.include_router(ws_router, tags=["ws"])
|
||||
|
||||
# Serve frontend static files (if built)
|
||||
if _WEB_STATIC.exists() and any(_WEB_STATIC.iterdir()):
|
||||
app.mount("/", _SPAStaticFiles(directory=str(_WEB_STATIC), html=True), name="static")
|
||||
logger.info("Serving frontend from %s", _WEB_STATIC)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def _startup():
|
||||
asyncio.create_task(_status_ticker())
|
||||
asyncio.create_task(_server_state_poller())
|
||||
logger.info("Web background tasks started")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Background tasks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _status_ticker() -> None:
|
||||
"""Broadcast status_snapshot to all clients every 5 seconds."""
|
||||
from web.deps import get_bot, get_comfy, get_config
|
||||
bus = get_bus()
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
try:
|
||||
bot = get_bot()
|
||||
comfy = get_comfy()
|
||||
config = get_config()
|
||||
|
||||
snapshot: dict = {}
|
||||
|
||||
if bot is not None:
|
||||
lat = bot.latency
|
||||
lat_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0
|
||||
import datetime as _dt
|
||||
start = getattr(bot, "start_time", None)
|
||||
uptime = ""
|
||||
if start:
|
||||
delta = _dt.datetime.now(_dt.timezone.utc) - start
|
||||
total = int(delta.total_seconds())
|
||||
h, rem = divmod(total, 3600)
|
||||
m, s = divmod(rem, 60)
|
||||
uptime = f"{h}h {m}m {s}s" if h else (f"{m}m {s}s" if m else f"{s}s")
|
||||
snapshot["bot"] = {"latency_ms": lat_ms, "uptime": uptime}
|
||||
|
||||
if comfy is not None:
|
||||
q = await comfy.get_comfy_queue()
|
||||
pending = len(q.get("queue_pending", [])) if q else 0
|
||||
running = len(q.get("queue_running", [])) if q else 0
|
||||
wm = getattr(comfy, "workflow_manager", None)
|
||||
wf_loaded = wm is not None and wm.get_workflow_template() is not None
|
||||
snapshot["comfy"] = {
|
||||
"server": comfy.server_address,
|
||||
"queue_pending": pending,
|
||||
"queue_running": running,
|
||||
"workflow_loaded": wf_loaded,
|
||||
"last_seed": comfy.last_seed,
|
||||
"total_generated": comfy.total_generated,
|
||||
}
|
||||
|
||||
if config is not None:
|
||||
from media_uploader import get_stats as get_upload_stats, is_running as upload_running
|
||||
try:
|
||||
us = get_upload_stats()
|
||||
snapshot["upload"] = {
|
||||
"configured": bool(config.media_upload_user),
|
||||
"running": upload_running(),
|
||||
"total_ok": us.total_ok,
|
||||
"total_fail": us.total_fail,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from web.deps import get_user_registry
|
||||
registry = get_user_registry()
|
||||
connected = bus.connected_users
|
||||
if connected and registry:
|
||||
for ul in connected:
|
||||
user_overrides = registry.get_state_manager(ul).get_overrides()
|
||||
await bus.broadcast_to_user(ul, "status_snapshot", {**snapshot, "overrides": user_overrides})
|
||||
else:
|
||||
await bus.broadcast("status_snapshot", snapshot)
|
||||
except Exception as exc:
|
||||
logger.debug("Status ticker error: %s", exc)
|
||||
|
||||
|
||||
async def _server_state_poller() -> None:
|
||||
"""Poll NSSM service state and broadcast server_state every 10 seconds."""
|
||||
from web.deps import get_config
|
||||
bus = get_bus()
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
try:
|
||||
config = get_config()
|
||||
if config is None:
|
||||
continue
|
||||
from commands.server import get_service_state
|
||||
from web.deps import get_comfy
|
||||
|
||||
async def _false():
|
||||
return False
|
||||
|
||||
comfy = get_comfy()
|
||||
service_state, http_reachable = await asyncio.gather(
|
||||
get_service_state(config.comfy_service_name),
|
||||
comfy.check_connection() if comfy else _false(),
|
||||
)
|
||||
await bus.broadcast("server_state", {
|
||||
"state": service_state,
|
||||
"http_reachable": http_reachable,
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.debug("Server state poller error: %s", exc)
|
||||
119
web/auth.py
Normal file
119
web/auth.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
web/auth.py
|
||||
===========
|
||||
|
||||
JWT authentication for the web UI.
|
||||
|
||||
Flow:
|
||||
- POST /api/auth/login {token} → verify invite token → issue JWT in httpOnly cookie
|
||||
- All /api/* require valid JWT via require_auth dependency
|
||||
- POST /api/admin/login {password} → issue admin JWT (admin: true claim)
|
||||
- WS /ws?token=<jwt> → authenticate via query param
|
||||
|
||||
JWT claims: {"sub": "<label>", "admin": bool, "exp": ...}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Cookie, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer
|
||||
|
||||
try:
|
||||
from jose import JWTError, jwt
|
||||
except ImportError:
|
||||
jwt = None # type: ignore
|
||||
JWTError = Exception # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
_COOKIE_NAME = "ttb_session"
|
||||
|
||||
|
||||
def _get_secret() -> str:
|
||||
from web.deps import get_config
|
||||
cfg = get_config()
|
||||
if cfg and cfg.web_secret_key:
|
||||
return cfg.web_secret_key
|
||||
raise RuntimeError(
|
||||
"WEB_SECRET_KEY must be set in the environment — "
|
||||
"refusing to run with an insecure default."
|
||||
)
|
||||
|
||||
|
||||
def create_jwt(label: str, *, admin: bool = False, expire_hours: int = 8) -> str:
|
||||
"""Create a signed JWT for the given user label."""
|
||||
if jwt is None:
|
||||
raise RuntimeError("python-jose is not installed (pip install python-jose[cryptography])")
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=expire_hours)
|
||||
payload = {"sub": label, "admin": admin, "exp": expire}
|
||||
return jwt.encode(payload, _get_secret(), algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_jwt(token: str) -> Optional[dict]:
|
||||
"""Decode and verify a JWT. Returns the payload or None on failure."""
|
||||
if jwt is None:
|
||||
return None
|
||||
try:
|
||||
return jwt.decode(token, _get_secret(), algorithms=[ALGORITHM])
|
||||
except JWTError as exc:
|
||||
logger.debug("JWT decode failed: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def verify_ws_token(token: str) -> Optional[dict]:
|
||||
"""Verify a JWT passed as a WebSocket query parameter."""
|
||||
return decode_jwt(token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FastAPI dependencies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def require_auth(ttb_session: Optional[str] = Cookie(default=None)) -> dict:
|
||||
"""
|
||||
FastAPI dependency that requires a valid JWT cookie.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
The decoded JWT payload (``sub``, ``admin`` fields).
|
||||
|
||||
Raises
|
||||
------
|
||||
HTTPException 401
|
||||
If the cookie is absent or the token is invalid/expired.
|
||||
"""
|
||||
if not ttb_session:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
)
|
||||
payload = decode_jwt(ttb_session)
|
||||
if payload is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
def require_admin(user: dict = Depends(require_auth)) -> dict:
|
||||
"""
|
||||
FastAPI dependency that requires an admin JWT.
|
||||
|
||||
Raises
|
||||
------
|
||||
HTTPException 403
|
||||
If the token is valid but not admin.
|
||||
"""
|
||||
if not user.get("admin"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin access required",
|
||||
)
|
||||
return user
|
||||
72
web/deps.py
Normal file
72
web/deps.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
web/deps.py
|
||||
===========
|
||||
|
||||
Shared bot reference for FastAPI dependency injection.
|
||||
|
||||
``set_bot()`` is called once from ``bot.py`` before starting Uvicorn.
|
||||
FastAPI route handlers use ``get_bot()``, ``get_comfy()``, etc. as
|
||||
Depends() callables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
_bot = None
|
||||
|
||||
|
||||
def set_bot(bot) -> None:
|
||||
"""Store the discord.py bot instance for DI access."""
|
||||
global _bot
|
||||
_bot = bot
|
||||
|
||||
|
||||
def get_bot():
|
||||
"""FastAPI dependency: return the bot instance."""
|
||||
return _bot
|
||||
|
||||
|
||||
def get_comfy():
|
||||
"""FastAPI dependency: return the ComfyClient."""
|
||||
if _bot is None:
|
||||
return None
|
||||
return getattr(_bot, "comfy", None)
|
||||
|
||||
|
||||
def get_config():
|
||||
"""FastAPI dependency: return the BotConfig."""
|
||||
if _bot is None:
|
||||
return None
|
||||
return getattr(_bot, "config", None)
|
||||
|
||||
|
||||
def get_state_manager():
|
||||
"""FastAPI dependency: return the WorkflowStateManager."""
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
return None
|
||||
return getattr(comfy, "state_manager", None)
|
||||
|
||||
|
||||
def get_workflow_manager():
|
||||
"""FastAPI dependency: return the WorkflowManager."""
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
return None
|
||||
return getattr(comfy, "workflow_manager", None)
|
||||
|
||||
|
||||
def get_inspector():
|
||||
"""FastAPI dependency: return the WorkflowInspector."""
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
return None
|
||||
return getattr(comfy, "inspector", None)
|
||||
|
||||
|
||||
def get_user_registry():
|
||||
"""FastAPI dependency: return the UserStateRegistry."""
|
||||
if _bot is None:
|
||||
return None
|
||||
return getattr(_bot, "user_registry", None)
|
||||
146
web/login_guard.py
Normal file
146
web/login_guard.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
web/login_guard.py
|
||||
==================
|
||||
|
||||
IP-based brute-force protection for login endpoints.
|
||||
|
||||
Tracks failed login attempts per IP in a rolling time window and issues a
|
||||
temporary ban when the threshold is exceeded. Uses only stdlib — no new
|
||||
pip packages required.
|
||||
|
||||
Usage
|
||||
-----
|
||||
from web.login_guard import get_guard, get_real_ip
|
||||
|
||||
@router.post("/login")
|
||||
async def login(request: Request, body: LoginRequest, response: Response):
|
||||
ip = get_real_ip(request)
|
||||
get_guard().check(ip) # raises 429 if locked out
|
||||
...
|
||||
if failure:
|
||||
get_guard().record_failure(ip)
|
||||
raise HTTPException(401, ...)
|
||||
get_guard().record_success(ip)
|
||||
...
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_real_ip(request: Request) -> str:
|
||||
"""Return the real client IP, honouring Cloudflare and common proxy headers.
|
||||
|
||||
Priority:
|
||||
1. ``CF-Connecting-IP`` (set by Cloudflare)
|
||||
2. ``X-Real-IP`` (set by nginx/traefik)
|
||||
3. ``request.client.host`` (direct connection fallback)
|
||||
"""
|
||||
cf_ip = request.headers.get("CF-Connecting-IP", "").strip()
|
||||
if cf_ip:
|
||||
return cf_ip
|
||||
real_ip = request.headers.get("X-Real-IP", "").strip()
|
||||
if real_ip:
|
||||
return real_ip
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
class BruteForceGuard:
|
||||
"""Rolling-window failure counter with automatic IP bans.
|
||||
|
||||
All state is in-process memory. A restart clears all bans and counters,
|
||||
which is acceptable — a brief restart already provides a natural backoff
|
||||
for a legitimate attacker.
|
||||
"""
|
||||
|
||||
WINDOW_SECS = 600 # rolling window: 10 minutes
|
||||
MAX_FAILURES = 10 # max failures before ban
|
||||
BAN_SECS = 3600 # ban duration: 1 hour
|
||||
|
||||
def __init__(self) -> None:
|
||||
# ip → list of failure timestamps (epoch floats)
|
||||
self._failures: Dict[str, List[float]] = defaultdict(list)
|
||||
# ip → ban expiry timestamp
|
||||
self._ban_until: Dict[str, float] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def check(self, ip: str) -> None:
|
||||
"""Raise HTTP 429 if the IP is banned or has exceeded the failure threshold.
|
||||
|
||||
Call this *before* doing any credential work so the lockout is
|
||||
evaluated even when the request body is malformed.
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
# Active ban?
|
||||
ban_expiry = self._ban_until.get(ip, 0)
|
||||
if ban_expiry > now:
|
||||
logger.warning("login_guard: blocked request from banned ip=%s (ban expires in %.0fs)", ip, ban_expiry - now)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Too many failed attempts. Try again later.",
|
||||
)
|
||||
|
||||
# Failure count within rolling window
|
||||
cutoff = now - self.WINDOW_SECS
|
||||
recent = [t for t in self._failures[ip] if t > cutoff]
|
||||
self._failures[ip] = recent # prune stale entries while we're here
|
||||
if len(recent) >= self.MAX_FAILURES:
|
||||
# Threshold just reached — apply ban now
|
||||
self._ban_until[ip] = now + self.BAN_SECS
|
||||
logger.warning(
|
||||
"login_guard: threshold reached, banning ip=%s for %ds",
|
||||
ip, self.BAN_SECS,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Too many failed attempts. Try again later.",
|
||||
)
|
||||
|
||||
def record_failure(self, ip: str) -> None:
|
||||
"""Record a failed login attempt for the given IP."""
|
||||
now = time.time()
|
||||
cutoff = now - self.WINDOW_SECS
|
||||
recent = [t for t in self._failures[ip] if t > cutoff]
|
||||
recent.append(now)
|
||||
self._failures[ip] = recent
|
||||
count = len(recent)
|
||||
logger.warning("login_guard: failure #%d from ip=%s", count, ip)
|
||||
|
||||
if count >= self.MAX_FAILURES:
|
||||
self._ban_until[ip] = now + self.BAN_SECS
|
||||
logger.warning(
|
||||
"login_guard: threshold reached, banning ip=%s for %ds",
|
||||
ip, self.BAN_SECS,
|
||||
)
|
||||
|
||||
def record_success(self, ip: str) -> None:
|
||||
"""Clear failure history and any active ban for the given IP."""
|
||||
self._failures.pop(ip, None)
|
||||
self._ban_until.pop(ip, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_guard: BruteForceGuard | None = None
|
||||
|
||||
|
||||
def get_guard() -> BruteForceGuard:
|
||||
"""Return the shared BruteForceGuard singleton (created on first call)."""
|
||||
global _guard
|
||||
if _guard is None:
|
||||
_guard = BruteForceGuard()
|
||||
return _guard
|
||||
1
web/routers/__init__.py
Normal file
1
web/routers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# web.routers package
|
||||
88
web/routers/admin_router.py
Normal file
88
web/routers/admin_router.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""POST /api/admin/login; GET/POST/DELETE /api/admin/tokens"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from web.auth import create_jwt, require_admin
|
||||
from web.deps import get_config
|
||||
from web.login_guard import get_guard, get_real_ip
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_COOKIE = "ttb_session"
|
||||
audit = logging.getLogger("audit")
|
||||
|
||||
|
||||
class AdminLoginRequest(BaseModel):
|
||||
password: str
|
||||
|
||||
|
||||
class CreateTokenRequest(BaseModel):
|
||||
label: str
|
||||
admin: bool = False
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def admin_login(body: AdminLoginRequest, request: Request, response: Response):
|
||||
"""Admin password login → admin JWT cookie."""
|
||||
config = get_config()
|
||||
expected_pw = config.admin_password if config else None
|
||||
|
||||
ip = get_real_ip(request)
|
||||
get_guard().check(ip)
|
||||
|
||||
# Constant-time comparison to prevent timing attacks
|
||||
if not expected_pw or not hmac.compare_digest(body.password, expected_pw):
|
||||
get_guard().record_failure(ip)
|
||||
audit.info("admin.login ip=%s success=False", ip)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Wrong password")
|
||||
|
||||
get_guard().record_success(ip)
|
||||
audit.info("admin.login ip=%s success=True", ip)
|
||||
|
||||
expire_hours = config.web_jwt_expire_hours if config else 8
|
||||
jwt_token = create_jwt("admin", admin=True, expire_hours=expire_hours)
|
||||
response.set_cookie(
|
||||
_COOKIE, jwt_token,
|
||||
httponly=True, secure=True, samesite="strict",
|
||||
max_age=expire_hours * 3600,
|
||||
)
|
||||
return {"label": "admin", "admin": True}
|
||||
|
||||
|
||||
@router.get("/tokens")
|
||||
async def list_tokens(_: dict = Depends(require_admin)):
|
||||
"""List all invite tokens (hashes shown, labels safe)."""
|
||||
from token_store import list_tokens as _list
|
||||
config = get_config()
|
||||
token_file = config.web_token_file if config else "invite_tokens.json"
|
||||
records = _list(token_file)
|
||||
# Don't return hashes to the UI
|
||||
return [{"id": r["id"], "label": r["label"], "admin": r.get("admin", False),
|
||||
"created_at": r.get("created_at")} for r in records]
|
||||
|
||||
|
||||
@router.post("/tokens")
|
||||
async def create_token(body: CreateTokenRequest, _: dict = Depends(require_admin)):
|
||||
"""Create a new invite token. Returns the plaintext token (shown once)."""
|
||||
from token_store import create_token as _create
|
||||
config = get_config()
|
||||
token_file = config.web_token_file if config else "invite_tokens.json"
|
||||
plaintext = _create(body.label, token_file, admin=body.admin)
|
||||
return {"token": plaintext, "label": body.label, "admin": body.admin}
|
||||
|
||||
|
||||
@router.delete("/tokens/{token_id}")
|
||||
async def revoke_token(token_id: str, _: dict = Depends(require_admin)):
|
||||
"""Revoke an invite token by ID."""
|
||||
from token_store import revoke_token as _revoke
|
||||
config = get_config()
|
||||
token_file = config.web_token_file if config else "invite_tokens.json"
|
||||
ok = _revoke(token_id, token_file)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="Token not found")
|
||||
return {"ok": True}
|
||||
64
web/routers/auth_router.py
Normal file
64
web/routers/auth_router.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""POST /api/auth/login|logout; GET /api/auth/me"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from web.auth import create_jwt, require_auth
|
||||
from web.deps import get_config
|
||||
from web.login_guard import get_guard, get_real_ip
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_COOKIE = "ttb_session"
|
||||
audit = logging.getLogger("audit")
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def login(body: LoginRequest, request: Request, response: Response):
|
||||
"""Exchange an invite token for a JWT session cookie."""
|
||||
from token_store import verify_token
|
||||
config = get_config()
|
||||
token_file = config.web_token_file if config else "invite_tokens.json"
|
||||
expire_hours = config.web_jwt_expire_hours if config else 8
|
||||
|
||||
ip = get_real_ip(request)
|
||||
get_guard().check(ip)
|
||||
|
||||
record = verify_token(body.token, token_file)
|
||||
if record is None:
|
||||
get_guard().record_failure(ip)
|
||||
audit.info("auth.login ip=%s success=False", ip)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||
|
||||
label: str = record["label"]
|
||||
admin: bool = record.get("admin", False)
|
||||
get_guard().record_success(ip)
|
||||
audit.info("auth.login ip=%s success=True label=%s", ip, label)
|
||||
|
||||
jwt_token = create_jwt(label, admin=admin, expire_hours=expire_hours)
|
||||
response.set_cookie(
|
||||
_COOKIE, jwt_token,
|
||||
httponly=True, secure=True, samesite="strict",
|
||||
max_age=expire_hours * 3600,
|
||||
)
|
||||
return {"label": label, "admin": admin}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(response: Response):
|
||||
"""Clear the session cookie."""
|
||||
response.delete_cookie(_COOKIE)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def me(user: dict = Depends(require_auth)):
|
||||
"""Return current user info."""
|
||||
return {"label": user["sub"], "admin": user.get("admin", False)}
|
||||
255
web/routers/generate_router.py
Normal file
255
web/routers/generate_router.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""POST /api/generate and /api/workflow-gen"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_comfy, get_config, get_user_registry
|
||||
from web.ws_bus import get_bus
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
negative_prompt: Optional[str] = None
|
||||
overrides: Optional[Dict[str, Any]] = None # extra per-request overrides
|
||||
|
||||
|
||||
class WorkflowGenRequest(BaseModel):
|
||||
count: int = 1
|
||||
overrides: Optional[Dict[str, Any]] = None # per-request overrides (merged with state)
|
||||
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate(body: GenerateRequest, user: dict = Depends(require_auth)):
|
||||
"""Submit a prompt-based generation to ComfyUI."""
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "ComfyUI client not available")
|
||||
|
||||
user_label: str = user["sub"]
|
||||
bus = get_bus()
|
||||
registry = get_user_registry()
|
||||
|
||||
# Temporary seed override from request
|
||||
if body.overrides and "seed" in body.overrides:
|
||||
seed_override = body.overrides["seed"]
|
||||
elif registry:
|
||||
seed_override = registry.get_state_manager(user_label).get_seed()
|
||||
else:
|
||||
seed_override = comfy.state_manager.get_seed()
|
||||
|
||||
overrides_for_gen = {"prompt": body.prompt}
|
||||
if body.negative_prompt:
|
||||
overrides_for_gen["negative_prompt"] = body.negative_prompt
|
||||
if seed_override is not None:
|
||||
overrides_for_gen["seed"] = seed_override
|
||||
|
||||
# Also apply any extra per-request overrides
|
||||
if body.overrides:
|
||||
overrides_for_gen.update(body.overrides)
|
||||
|
||||
# Get queue position estimate
|
||||
depth = await comfy.get_queue_depth()
|
||||
|
||||
# Start generation as background task so we can return the prompt_id immediately
|
||||
prompt_id_holder: list = []
|
||||
|
||||
async def _run():
|
||||
# Use the user's own workflow template
|
||||
if registry:
|
||||
template = registry.get_workflow_template(user_label)
|
||||
else:
|
||||
template = comfy.workflow_manager.get_workflow_template()
|
||||
if not template:
|
||||
await bus.broadcast_to_user(user_label, "generation_error", {
|
||||
"prompt_id": None, "error": "No workflow template loaded"
|
||||
})
|
||||
return
|
||||
import uuid
|
||||
pid = str(uuid.uuid4())
|
||||
prompt_id_holder.append(pid)
|
||||
|
||||
def on_progress(node, pid_):
|
||||
asyncio.create_task(bus.broadcast("node_executing", {
|
||||
"node": node, "prompt_id": pid_
|
||||
}))
|
||||
|
||||
workflow, applied = comfy.inspector.inject_overrides(template, overrides_for_gen)
|
||||
seed_used = applied.get("seed")
|
||||
comfy.last_seed = seed_used
|
||||
|
||||
try:
|
||||
images, videos = await comfy._general_generate(workflow, pid, on_progress)
|
||||
except Exception as exc:
|
||||
logger.exception("Generation error for prompt %s", pid)
|
||||
await bus.broadcast_to_user(user_label, "generation_error", {
|
||||
"prompt_id": pid, "error": str(exc)
|
||||
})
|
||||
return
|
||||
|
||||
comfy.last_prompt_id = pid
|
||||
comfy.total_generated += 1
|
||||
|
||||
# Persist to DB before flush_pending deletes local files
|
||||
config = get_config()
|
||||
try:
|
||||
from generation_db import record_generation, record_file
|
||||
gen_id = record_generation(pid, "web", user_label, overrides_for_gen, seed_used)
|
||||
for i, img_data in enumerate(images):
|
||||
record_file(gen_id, f"image_{i:04d}.png", img_data)
|
||||
if config and videos:
|
||||
for vid in videos:
|
||||
vsub = vid.get("video_subfolder", "")
|
||||
vname = vid.get("video_name", "")
|
||||
vpath = (
|
||||
Path(config.comfy_output_path) / vsub / vname
|
||||
if vsub
|
||||
else Path(config.comfy_output_path) / vname
|
||||
)
|
||||
try:
|
||||
record_file(gen_id, vname, vpath.read_bytes())
|
||||
except OSError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to record generation to DB: %s", exc)
|
||||
|
||||
# Flush auto-upload
|
||||
if config:
|
||||
from media_uploader import flush_pending
|
||||
asyncio.create_task(flush_pending(
|
||||
Path(config.comfy_output_path),
|
||||
config.media_upload_user,
|
||||
config.media_upload_pass,
|
||||
))
|
||||
|
||||
await bus.broadcast("queue_update", {
|
||||
"prompt_id": pid,
|
||||
"status": "complete",
|
||||
})
|
||||
await bus.broadcast_to_user(user_label, "generation_complete", {
|
||||
"prompt_id": pid,
|
||||
"seed": seed_used,
|
||||
"image_count": len(images),
|
||||
"video_count": len(videos),
|
||||
})
|
||||
|
||||
asyncio.create_task(_run())
|
||||
|
||||
return {
|
||||
"queued": True,
|
||||
"queue_position": depth + 1,
|
||||
"message": "Generation submitted to ComfyUI",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/workflow-gen")
|
||||
async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_auth)):
|
||||
"""Submit workflow-based generation(s) to ComfyUI."""
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "ComfyUI client not available")
|
||||
|
||||
user_label: str = user["sub"]
|
||||
bus = get_bus()
|
||||
registry = get_user_registry()
|
||||
count = max(1, min(body.count, 20)) # cap at 20
|
||||
|
||||
async def _run_one():
|
||||
# Use the user's own state and template
|
||||
if registry:
|
||||
user_sm = registry.get_state_manager(user_label)
|
||||
user_template = registry.get_workflow_template(user_label)
|
||||
else:
|
||||
user_sm = comfy.state_manager
|
||||
user_template = comfy.workflow_manager.get_workflow_template()
|
||||
|
||||
if not user_template:
|
||||
await bus.broadcast_to_user(user_label, "generation_error", {
|
||||
"prompt_id": None, "error": "No workflow template loaded"
|
||||
})
|
||||
return
|
||||
|
||||
overrides = user_sm.get_overrides()
|
||||
if body.overrides:
|
||||
overrides = {**overrides, **body.overrides}
|
||||
|
||||
import uuid
|
||||
pid = str(uuid.uuid4())
|
||||
|
||||
def on_progress(node, pid_):
|
||||
asyncio.create_task(bus.broadcast("node_executing", {
|
||||
"node": node, "prompt_id": pid_
|
||||
}))
|
||||
|
||||
workflow, applied = comfy.inspector.inject_overrides(user_template, overrides)
|
||||
seed_used = applied.get("seed")
|
||||
comfy.last_seed = seed_used
|
||||
|
||||
try:
|
||||
images, videos = await comfy._general_generate(workflow, pid, on_progress)
|
||||
except Exception as exc:
|
||||
logger.exception("Workflow gen error")
|
||||
await bus.broadcast_to_user(user_label, "generation_error", {
|
||||
"prompt_id": None, "error": str(exc)
|
||||
})
|
||||
return
|
||||
|
||||
comfy.last_prompt_id = pid
|
||||
comfy.total_generated += 1
|
||||
|
||||
config = get_config()
|
||||
try:
|
||||
from generation_db import record_generation, record_file
|
||||
gen_id = record_generation(pid, "web", user_label, overrides, seed_used)
|
||||
for i, img_data in enumerate(images):
|
||||
record_file(gen_id, f"image_{i:04d}.png", img_data)
|
||||
if config and videos:
|
||||
for vid in videos:
|
||||
vsub = vid.get("video_subfolder", "")
|
||||
vname = vid.get("video_name", "")
|
||||
vpath = (
|
||||
Path(config.comfy_output_path) / vsub / vname
|
||||
if vsub
|
||||
else Path(config.comfy_output_path) / vname
|
||||
)
|
||||
try:
|
||||
record_file(gen_id, vname, vpath.read_bytes())
|
||||
except OSError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to record generation to DB: %s", exc)
|
||||
|
||||
if config:
|
||||
from media_uploader import flush_pending
|
||||
asyncio.create_task(flush_pending(
|
||||
Path(config.comfy_output_path),
|
||||
config.media_upload_user,
|
||||
config.media_upload_pass,
|
||||
))
|
||||
|
||||
await bus.broadcast("queue_update", {"prompt_id": pid, "status": "complete"})
|
||||
await bus.broadcast_to_user(user_label, "generation_complete", {
|
||||
"prompt_id": pid,
|
||||
"seed": seed_used,
|
||||
"image_count": len(images),
|
||||
"video_count": len(videos),
|
||||
})
|
||||
|
||||
depth = await comfy.get_queue_depth()
|
||||
for _ in range(count):
|
||||
asyncio.create_task(_run_one())
|
||||
|
||||
return {
|
||||
"queued": True,
|
||||
"count": count,
|
||||
"queue_position": depth + 1,
|
||||
}
|
||||
143
web/routers/history_router.py
Normal file
143
web/routers/history_router.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""GET /api/history; GET /api/history/{prompt_id}/images; GET /api/history/{prompt_id}/file/{filename};
|
||||
POST /api/history/{prompt_id}/share; DELETE /api/history/{prompt_id}/share"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import Response
|
||||
|
||||
from web.auth import require_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _assert_owns(prompt_id: str, user: dict) -> None:
|
||||
"""Raise 404 if the generation doesn't exist or doesn't belong to the user.
|
||||
|
||||
Returning the same 404 for both cases prevents leaking whether a
|
||||
prompt_id exists to users who don't own it. Admins bypass this check.
|
||||
"""
|
||||
if user.get("admin"):
|
||||
return
|
||||
from generation_db import get_generation
|
||||
gen = get_generation(prompt_id)
|
||||
if gen is None or gen["user_label"] != user["sub"]:
|
||||
raise HTTPException(404, "Not found")
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_history(
|
||||
user: dict = Depends(require_auth),
|
||||
q: Optional[str] = Query(None, description="Keyword to search in overrides JSON"),
|
||||
):
|
||||
"""Return generation history. Admins see all; regular users see only their own.
|
||||
Pass ?q=keyword to filter by prompt text or any override field."""
|
||||
from generation_db import (
|
||||
get_history as db_get_history,
|
||||
get_history_for_user,
|
||||
search_history,
|
||||
search_history_for_user,
|
||||
)
|
||||
if q and q.strip():
|
||||
if user.get("admin"):
|
||||
return {"history": search_history(q.strip(), limit=50)}
|
||||
return {"history": search_history_for_user(user["sub"], q.strip(), limit=50)}
|
||||
if user.get("admin"):
|
||||
return {"history": db_get_history(limit=50)}
|
||||
return {"history": get_history_for_user(user["sub"], limit=50)}
|
||||
|
||||
|
||||
@router.get("/{prompt_id}/images")
|
||||
async def get_history_images(prompt_id: str, user: dict = Depends(require_auth)):
|
||||
"""
|
||||
Fetch output files for a past generation.
|
||||
|
||||
Returns base64-encoded blobs from the local SQLite DB — works even after
|
||||
``flush_pending`` has deleted the files from disk.
|
||||
"""
|
||||
_assert_owns(prompt_id, user)
|
||||
from generation_db import get_files
|
||||
files = get_files(prompt_id)
|
||||
if not files:
|
||||
raise HTTPException(404, f"No files found for prompt_id {prompt_id!r}")
|
||||
return {
|
||||
"prompt_id": prompt_id,
|
||||
"images": [
|
||||
{
|
||||
"filename": f["filename"],
|
||||
"data": base64.b64encode(f["data"]).decode() if not f["mime_type"].startswith("video/") else None,
|
||||
"mime_type": f["mime_type"],
|
||||
}
|
||||
for f in files
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{prompt_id}/file/{filename}")
|
||||
async def get_history_file(
|
||||
prompt_id: str,
|
||||
filename: str,
|
||||
request: Request,
|
||||
user: dict = Depends(require_auth),
|
||||
):
|
||||
"""Stream a single output file, with HTTP range request support for video seeking."""
|
||||
_assert_owns(prompt_id, user)
|
||||
from generation_db import get_files
|
||||
files = get_files(prompt_id)
|
||||
matched = next((f for f in files if f["filename"] == filename), None)
|
||||
if matched is None:
|
||||
raise HTTPException(404, f"File {filename!r} not found for prompt_id {prompt_id!r}")
|
||||
|
||||
data: bytes = matched["data"]
|
||||
mime: str = matched["mime_type"]
|
||||
total = len(data)
|
||||
|
||||
range_header = request.headers.get("range")
|
||||
if range_header:
|
||||
range_val = range_header.replace("bytes=", "")
|
||||
start_str, _, end_str = range_val.partition("-")
|
||||
start = int(start_str) if start_str else 0
|
||||
end = int(end_str) if end_str else total - 1
|
||||
end = min(end, total - 1)
|
||||
chunk = data[start : end + 1]
|
||||
return Response(
|
||||
content=chunk,
|
||||
status_code=206,
|
||||
media_type=mime,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{total}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(len(chunk)),
|
||||
},
|
||||
)
|
||||
|
||||
return Response(
|
||||
content=data,
|
||||
media_type=mime,
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(total),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{prompt_id}/share")
|
||||
async def create_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
|
||||
"""Create a share token for a generation. Only the owner may share."""
|
||||
# Use the same 404-for-everything helper to avoid leaking prompt_id existence
|
||||
_assert_owns(prompt_id, user)
|
||||
from generation_db import create_share as db_create_share
|
||||
token = db_create_share(prompt_id, user["sub"])
|
||||
return {"share_token": token}
|
||||
|
||||
|
||||
@router.delete("/{prompt_id}/share")
|
||||
async def revoke_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
|
||||
"""Revoke a share token for a generation. Only the owner may revoke."""
|
||||
from generation_db import revoke_share as db_revoke_share
|
||||
deleted = db_revoke_share(prompt_id, user["sub"])
|
||||
if not deleted:
|
||||
raise HTTPException(404, "No active share found for this generation")
|
||||
return {"ok": True}
|
||||
194
web/routers/inputs_router.py
Normal file
194
web/routers/inputs_router.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""GET/POST/DELETE /api/inputs; GET /api/inputs/{id}/image; POST /api/inputs/{id}/activate"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import Response
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_config, get_user_registry
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_inputs(_: dict = Depends(require_auth)):
|
||||
"""List all input images (Discord + web uploads)."""
|
||||
from input_image_db import get_all_images
|
||||
rows = get_all_images()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def upload_input(
|
||||
file: UploadFile = File(...),
|
||||
slot_key: Optional[str] = Form(default=None),
|
||||
user: dict = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
Upload an input image.
|
||||
|
||||
Stores image bytes directly in SQLite. If *slot_key* is provided the
|
||||
image is immediately activated for that slot (writes to ComfyUI input
|
||||
folder and updates the user's state override).
|
||||
|
||||
The physical slot file uses a namespaced key ``<user_label>_<slot_key>``
|
||||
so concurrent users each get their own active image file.
|
||||
"""
|
||||
config = get_config()
|
||||
if config is None:
|
||||
raise HTTPException(503, "Config not available")
|
||||
|
||||
data = await file.read()
|
||||
filename = file.filename or "upload.png"
|
||||
|
||||
from input_image_db import upsert_image, activate_image_for_slot
|
||||
row_id = upsert_image(
|
||||
original_message_id=0, # sentinel for web uploads
|
||||
bot_reply_id=0,
|
||||
channel_id=0,
|
||||
filename=filename,
|
||||
image_data=data,
|
||||
)
|
||||
|
||||
activated_filename: str | None = None
|
||||
if slot_key:
|
||||
user_label: str = user["sub"]
|
||||
namespaced_key = f"{user_label}_{slot_key}"
|
||||
activated_filename = activate_image_for_slot(
|
||||
row_id, namespaced_key, config.comfy_input_path
|
||||
)
|
||||
registry = get_user_registry()
|
||||
if registry:
|
||||
registry.get_state_manager(user_label).set_override(slot_key, activated_filename)
|
||||
else:
|
||||
from web.deps import get_comfy
|
||||
comfy = get_comfy()
|
||||
if comfy:
|
||||
comfy.state_manager.set_override(slot_key, activated_filename)
|
||||
|
||||
return {"id": row_id, "filename": filename, "slot_key": slot_key, "activated_filename": activated_filename}
|
||||
|
||||
|
||||
@router.post("/{row_id}/activate")
|
||||
async def activate_input(
|
||||
row_id: int,
|
||||
slot_key: str = Body(default="input_image", embed=True),
|
||||
user: dict = Depends(require_auth),
|
||||
):
|
||||
"""Write the stored image to the ComfyUI input folder and set the user's slot override."""
|
||||
config = get_config()
|
||||
if config is None:
|
||||
raise HTTPException(503, "Config not available")
|
||||
|
||||
from input_image_db import get_image, activate_image_for_slot
|
||||
row = get_image(row_id)
|
||||
if row is None:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
user_label: str = user["sub"]
|
||||
namespaced_key = f"{user_label}_{slot_key}"
|
||||
|
||||
try:
|
||||
filename = activate_image_for_slot(row_id, namespaced_key, config.comfy_input_path)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(409, str(exc))
|
||||
|
||||
registry = get_user_registry()
|
||||
if registry:
|
||||
registry.get_state_manager(user_label).set_override(slot_key, filename)
|
||||
else:
|
||||
from web.deps import get_comfy
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "State manager not available")
|
||||
comfy.state_manager.set_override(slot_key, filename)
|
||||
|
||||
return {"ok": True, "slot_key": slot_key, "filename": filename}
|
||||
|
||||
|
||||
@router.delete("/{row_id}")
|
||||
async def delete_input(row_id: int, _: dict = Depends(require_auth)):
|
||||
"""Delete an input image record (and its active slot file if applicable)."""
|
||||
from input_image_db import get_image, delete_image
|
||||
row = get_image(row_id)
|
||||
if row is None:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
config = get_config()
|
||||
delete_image(row_id, comfy_input_path=config.comfy_input_path if config else None)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/{row_id}/image")
|
||||
async def get_input_image(row_id: int, _: dict = Depends(require_auth)):
|
||||
"""Serve the raw image bytes stored in the database for a given input image row."""
|
||||
from input_image_db import get_image, get_image_data
|
||||
row = get_image(row_id)
|
||||
if row is None:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
data = get_image_data(row_id)
|
||||
if data is None:
|
||||
raise HTTPException(404, "Image data not available — re-upload to backfill")
|
||||
|
||||
mime, _ = mimetypes.guess_type(row["filename"])
|
||||
return Response(content=data, media_type=mime or "application/octet-stream")
|
||||
|
||||
|
||||
def _pil_resize_response(data: bytes, filename: str, max_size: int, quality: int) -> Response:
|
||||
"""Resize image bytes with Pillow and return a JPEG Response. Raises on failure."""
|
||||
import io
|
||||
from PIL import Image as _PIL
|
||||
img = _PIL.open(io.BytesIO(data))
|
||||
img.thumbnail((max_size, max_size), _PIL.LANCZOS)
|
||||
buf = io.BytesIO()
|
||||
img.convert("RGB").save(buf, "JPEG", quality=quality, optimize=True)
|
||||
return Response(
|
||||
content=buf.getvalue(),
|
||||
media_type="image/jpeg",
|
||||
headers={"Cache-Control": "public, max-age=86400"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{row_id}/thumb")
|
||||
async def get_input_thumb(row_id: int, _: dict = Depends(require_auth)):
|
||||
"""Serve a small compressed thumbnail (max 200 px, JPEG 65 %) for fast previews."""
|
||||
from input_image_db import get_image, get_image_data
|
||||
row = get_image(row_id)
|
||||
if row is None:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
data = get_image_data(row_id)
|
||||
if data is None:
|
||||
raise HTTPException(404, "Image data not available — re-upload to backfill")
|
||||
|
||||
try:
|
||||
return _pil_resize_response(data, row["filename"], max_size=200, quality=65)
|
||||
except Exception:
|
||||
mime, _ = mimetypes.guess_type(row["filename"])
|
||||
return Response(content=data, media_type=mime or "application/octet-stream")
|
||||
|
||||
|
||||
@router.get("/{row_id}/mid")
|
||||
async def get_input_mid(row_id: int, _: dict = Depends(require_auth)):
|
||||
"""Serve a medium compressed image (max 800 px, JPEG 80 %) for progressive loading."""
|
||||
from input_image_db import get_image, get_image_data
|
||||
row = get_image(row_id)
|
||||
if row is None:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
data = get_image_data(row_id)
|
||||
if data is None:
|
||||
raise HTTPException(404, "Image data not available — re-upload to backfill")
|
||||
|
||||
try:
|
||||
return _pil_resize_response(data, row["filename"], max_size=800, quality=80)
|
||||
except Exception:
|
||||
mime, _ = mimetypes.guess_type(row["filename"])
|
||||
return Response(content=data, media_type=mime or "application/octet-stream")
|
||||
153
web/routers/presets_router.py
Normal file
153
web/routers/presets_router.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""CRUD for workflow presets via /api/presets"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_comfy, get_user_registry
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SavePresetRequest(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class SaveFromHistoryRequest(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
def _get_pm():
|
||||
from web.deps import get_bot
|
||||
bot = get_bot()
|
||||
pm = getattr(bot, "preset_manager", None) if bot else None
|
||||
if pm is None:
|
||||
from preset_manager import PresetManager
|
||||
pm = PresetManager()
|
||||
return pm
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_presets(_: dict = Depends(require_auth)):
|
||||
pm = _get_pm()
|
||||
return {"presets": pm.list_preset_details()}
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def save_preset(body: SavePresetRequest, user: dict = Depends(require_auth)):
|
||||
"""Capture the user's overrides + workflow template as a named preset."""
|
||||
user_label: str = user["sub"]
|
||||
registry = get_user_registry()
|
||||
pm = _get_pm()
|
||||
|
||||
if registry:
|
||||
workflow_template = registry.get_workflow_template(user_label)
|
||||
overrides = registry.get_state_manager(user_label).get_overrides()
|
||||
else:
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "ComfyUI not available")
|
||||
workflow_template = comfy.get_workflow_template()
|
||||
overrides = comfy.state_manager.get_overrides()
|
||||
|
||||
try:
|
||||
pm.save(body.name, workflow_template, overrides, owner=user_label, description=body.description)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(400, str(exc))
|
||||
return {"ok": True, "name": body.name}
|
||||
|
||||
|
||||
@router.get("/{name}")
|
||||
async def get_preset(name: str, _: dict = Depends(require_auth)):
|
||||
pm = _get_pm()
|
||||
data = pm.load(name)
|
||||
if data is None:
|
||||
raise HTTPException(404, "Preset not found")
|
||||
return data
|
||||
|
||||
|
||||
@router.post("/{name}/load")
|
||||
async def load_preset(name: str, user: dict = Depends(require_auth)):
|
||||
"""Restore overrides (and optionally workflow template) from a preset into the user's state."""
|
||||
pm = _get_pm()
|
||||
data = pm.load(name)
|
||||
if data is None:
|
||||
raise HTTPException(404, "Preset not found")
|
||||
|
||||
user_label: str = user["sub"]
|
||||
registry = get_user_registry()
|
||||
|
||||
if registry:
|
||||
wf = data.get("workflow")
|
||||
if wf:
|
||||
registry.set_workflow(user_label, wf, name)
|
||||
else:
|
||||
# No workflow in preset — just clear overrides and restore state
|
||||
registry.get_state_manager(user_label).clear_overrides()
|
||||
state = data.get("state", {})
|
||||
sm = registry.get_state_manager(user_label)
|
||||
for k, v in state.items():
|
||||
if v is not None:
|
||||
sm.set_override(k, v)
|
||||
else:
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "ComfyUI not available")
|
||||
comfy.state_manager.clear_overrides()
|
||||
state = data.get("state", {})
|
||||
for k, v in state.items():
|
||||
if v is not None:
|
||||
comfy.state_manager.set_override(k, v)
|
||||
wf = data.get("workflow")
|
||||
if wf:
|
||||
comfy.workflow_manager.set_workflow_template(wf)
|
||||
|
||||
return {"ok": True, "name": name, "overrides_restored": list(data.get("state", {}).keys())}
|
||||
|
||||
|
||||
@router.delete("/{name}")
|
||||
async def delete_preset(name: str, user: dict = Depends(require_auth)):
|
||||
pm = _get_pm()
|
||||
data = pm.load(name)
|
||||
if data is None:
|
||||
raise HTTPException(404, "Preset not found")
|
||||
|
||||
user_label: str = user["sub"]
|
||||
is_admin = user.get("admin") is True
|
||||
owner = data.get("owner")
|
||||
if owner is not None and owner != user_label and not is_admin:
|
||||
raise HTTPException(403, "You do not have permission to delete this preset")
|
||||
|
||||
pm.delete(name)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/from-history/{prompt_id}")
|
||||
async def save_preset_from_history(
|
||||
prompt_id: str,
|
||||
body: SaveFromHistoryRequest,
|
||||
user: dict = Depends(require_auth),
|
||||
):
|
||||
"""Create a preset from a past generation's overrides."""
|
||||
from generation_db import get_generation_full
|
||||
|
||||
gen = get_generation_full(prompt_id)
|
||||
if gen is None:
|
||||
raise HTTPException(404, "Generation not found")
|
||||
|
||||
user_label: str = user["sub"]
|
||||
is_admin = user.get("admin") is True
|
||||
if not is_admin and gen.get("user_label") != user_label:
|
||||
raise HTTPException(404, "Generation not found")
|
||||
|
||||
pm = _get_pm()
|
||||
try:
|
||||
pm.save(body.name, None, gen["overrides"], owner=user_label, description=body.description)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(400, str(exc))
|
||||
return {"ok": True, "name": body.name}
|
||||
90
web/routers/server_router.py
Normal file
90
web/routers/server_router.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""GET/POST /api/server/{action}; GET /api/logs/tail"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_config, get_comfy
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/server/status")
|
||||
async def server_status(_: dict = Depends(require_auth)):
|
||||
"""Return NSSM service state and HTTP health."""
|
||||
config = get_config()
|
||||
if config is None:
|
||||
raise HTTPException(503, "Config not available")
|
||||
from commands.server import get_service_state
|
||||
import asyncio
|
||||
|
||||
async def _false():
|
||||
return False
|
||||
|
||||
comfy = get_comfy()
|
||||
service_state, http_ok = await asyncio.gather(
|
||||
get_service_state(config.comfy_service_name),
|
||||
comfy.check_connection() if comfy else _false(),
|
||||
)
|
||||
return {"service_state": service_state, "http_reachable": http_ok}
|
||||
|
||||
|
||||
@router.post("/server/{action}")
|
||||
async def server_action(action: str, _: dict = Depends(require_auth)):
|
||||
"""Control the ComfyUI service: start | stop | restart | install | uninstall"""
|
||||
config = get_config()
|
||||
if config is None:
|
||||
raise HTTPException(503, "Config not available")
|
||||
|
||||
valid_actions = {"start", "stop", "restart", "install", "uninstall"}
|
||||
if action not in valid_actions:
|
||||
raise HTTPException(400, f"Invalid action '{action}'")
|
||||
|
||||
from commands.server import _nssm, _install_service
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
if action == "install":
|
||||
ok, msg = await _install_service(config)
|
||||
if not ok:
|
||||
raise HTTPException(500, msg)
|
||||
elif action == "uninstall":
|
||||
await _nssm("stop", config.comfy_service_name)
|
||||
await _nssm("remove", config.comfy_service_name, "confirm")
|
||||
elif action == "start":
|
||||
await _nssm("start", config.comfy_service_name)
|
||||
elif action == "stop":
|
||||
await _nssm("stop", config.comfy_service_name)
|
||||
elif action == "restart":
|
||||
await _nssm("restart", config.comfy_service_name)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, str(exc))
|
||||
|
||||
return {"ok": True, "action": action}
|
||||
|
||||
|
||||
@router.get("/logs/tail")
|
||||
async def tail_logs(lines: int = 100, _: dict = Depends(require_auth)):
|
||||
"""Tail the ComfyUI log file."""
|
||||
config = get_config()
|
||||
if config is None or not config.comfy_log_dir:
|
||||
raise HTTPException(503, "Log directory not configured")
|
||||
|
||||
log_dir = Path(config.comfy_log_dir)
|
||||
log_file = log_dir / "comfyui.log"
|
||||
if not log_file.exists():
|
||||
return {"lines": []}
|
||||
|
||||
try:
|
||||
with open(log_file, "r", encoding="utf-8", errors="replace") as f:
|
||||
all_lines = f.readlines()
|
||||
tail = all_lines[-min(lines, len(all_lines)):]
|
||||
return {"lines": [ln.rstrip("\n") for ln in tail]}
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, str(exc))
|
||||
85
web/routers/share_router.py
Normal file
85
web/routers/share_router.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""GET /api/share/{token}; GET /api/share/{token}/file/{filename}"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
|
||||
from web.auth import require_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{token}")
|
||||
async def get_share(token: str, _: dict = Depends(require_auth)):
|
||||
"""Fetch share metadata and images. Any authenticated user may view a valid share link."""
|
||||
from generation_db import get_share_by_token, get_files
|
||||
gen = get_share_by_token(token)
|
||||
if gen is None:
|
||||
raise HTTPException(404, "Share not found or revoked")
|
||||
files = get_files(gen["prompt_id"])
|
||||
return {
|
||||
"prompt_id": gen["prompt_id"],
|
||||
"created_at": gen["created_at"],
|
||||
"overrides": gen["overrides"],
|
||||
"seed": gen["seed"],
|
||||
"images": [
|
||||
{
|
||||
"filename": f["filename"],
|
||||
"data": base64.b64encode(f["data"]).decode() if not f["mime_type"].startswith("video/") else None,
|
||||
"mime_type": f["mime_type"],
|
||||
}
|
||||
for f in files
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{token}/file/{filename}")
|
||||
async def get_share_file(
|
||||
token: str,
|
||||
filename: str,
|
||||
request: Request,
|
||||
_: dict = Depends(require_auth),
|
||||
):
|
||||
"""Stream a single output file via share token, with HTTP range support for video seeking."""
|
||||
from generation_db import get_share_by_token, get_files
|
||||
gen = get_share_by_token(token)
|
||||
if gen is None:
|
||||
raise HTTPException(404, "Share not found or revoked")
|
||||
files = get_files(gen["prompt_id"])
|
||||
matched = next((f for f in files if f["filename"] == filename), None)
|
||||
if matched is None:
|
||||
raise HTTPException(404, f"File {filename!r} not found")
|
||||
|
||||
data: bytes = matched["data"]
|
||||
mime: str = matched["mime_type"]
|
||||
total = len(data)
|
||||
|
||||
range_header = request.headers.get("range")
|
||||
if range_header:
|
||||
range_val = range_header.replace("bytes=", "")
|
||||
start_str, _, end_str = range_val.partition("-")
|
||||
start = int(start_str) if start_str else 0
|
||||
end = int(end_str) if end_str else total - 1
|
||||
end = min(end, total - 1)
|
||||
chunk = data[start : end + 1]
|
||||
return Response(
|
||||
content=chunk,
|
||||
status_code=206,
|
||||
media_type=mime,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{total}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(len(chunk)),
|
||||
},
|
||||
)
|
||||
|
||||
return Response(
|
||||
content=data,
|
||||
media_type=mime,
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(total),
|
||||
},
|
||||
)
|
||||
53
web/routers/state_router.py
Normal file
53
web/routers/state_router.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""GET/PUT /api/state; DELETE /api/state/{key}"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_config, get_user_registry
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _get_user_sm(user: dict):
|
||||
"""Return the per-user WorkflowStateManager, raising 503 if unavailable."""
|
||||
registry = get_user_registry()
|
||||
if registry is None:
|
||||
raise HTTPException(503, "State manager not available")
|
||||
return registry.get_state_manager(user["sub"])
|
||||
|
||||
|
||||
@router.get("/state")
|
||||
async def get_state(user: dict = Depends(require_auth)):
|
||||
"""Return all current overrides for the authenticated user."""
|
||||
sm = _get_user_sm(user)
|
||||
return sm.get_overrides()
|
||||
|
||||
|
||||
@router.put("/state")
|
||||
async def put_state(body: Dict[str, Any], user: dict = Depends(require_auth)):
|
||||
"""Merge override values. Pass ``null`` as a value to delete a key."""
|
||||
sm = _get_user_sm(user)
|
||||
for key, value in body.items():
|
||||
if value is None:
|
||||
sm.delete_override(key)
|
||||
else:
|
||||
sm.set_override(key, value)
|
||||
return sm.get_overrides()
|
||||
|
||||
|
||||
@router.delete("/state/{key}")
|
||||
async def delete_state_key(key: str, user: dict = Depends(require_auth)):
|
||||
"""Remove a single override key, and clean up any associated slot file."""
|
||||
sm = _get_user_sm(user)
|
||||
sm.delete_override(key)
|
||||
|
||||
config = get_config()
|
||||
if config:
|
||||
from input_image_db import deactivate_image_slot
|
||||
user_label: str = user["sub"]
|
||||
deactivate_image_slot(f"{user_label}_{key}", config.comfy_input_path)
|
||||
|
||||
return {"ok": True, "key": key}
|
||||
74
web/routers/status_router.py
Normal file
74
web/routers/status_router.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""GET /api/status — polling fallback for clients that can't use WebSocket"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_bot, get_comfy, get_config
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_status(_: dict = Depends(require_auth)):
|
||||
"""Return a full status snapshot."""
|
||||
bot = get_bot()
|
||||
comfy = get_comfy()
|
||||
config = get_config()
|
||||
|
||||
snap: dict = {}
|
||||
|
||||
if bot is not None:
|
||||
import datetime as _dt
|
||||
lat = bot.latency
|
||||
lat_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0
|
||||
start = getattr(bot, "start_time", None)
|
||||
uptime = ""
|
||||
if start:
|
||||
delta = _dt.datetime.now(_dt.timezone.utc) - start
|
||||
total = int(delta.total_seconds())
|
||||
h, rem = divmod(total, 3600)
|
||||
m, s = divmod(rem, 60)
|
||||
uptime = f"{h}h {m}m {s}s" if h else (f"{m}m {s}s" if m else f"{s}s")
|
||||
snap["bot"] = {"latency_ms": lat_ms, "uptime": uptime}
|
||||
|
||||
if comfy is not None:
|
||||
q_task = asyncio.create_task(comfy.get_comfy_queue())
|
||||
conn_task = asyncio.create_task(comfy.check_connection())
|
||||
q, reachable = await asyncio.gather(q_task, conn_task)
|
||||
|
||||
pending = len(q.get("queue_pending", [])) if q else 0
|
||||
running = len(q.get("queue_running", [])) if q else 0
|
||||
wm = getattr(comfy, "workflow_manager", None)
|
||||
wf_loaded = wm is not None and wm.get_workflow_template() is not None
|
||||
|
||||
snap["comfy"] = {
|
||||
"server": comfy.server_address,
|
||||
"reachable": reachable,
|
||||
"queue_pending": pending,
|
||||
"queue_running": running,
|
||||
"workflow_loaded": wf_loaded,
|
||||
"last_seed": comfy.last_seed,
|
||||
"total_generated": comfy.total_generated,
|
||||
}
|
||||
snap["overrides"] = comfy.state_manager.get_overrides()
|
||||
|
||||
if config is not None:
|
||||
from commands.server import get_service_state
|
||||
service_state = await get_service_state(config.comfy_service_name)
|
||||
snap["service"] = {"state": service_state}
|
||||
|
||||
try:
|
||||
from media_uploader import get_stats as us_fn, is_running as ur_fn
|
||||
us = us_fn()
|
||||
snap["upload"] = {
|
||||
"configured": bool(config.media_upload_user),
|
||||
"running": ur_fn(),
|
||||
"total_ok": us.total_ok,
|
||||
"total_fail": us.total_fail,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return snap
|
||||
175
web/routers/workflow_router.py
Normal file
175
web/routers/workflow_router.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
GET /api/workflow — current workflow info
|
||||
GET /api/workflow/inputs — dynamic NodeInput list
|
||||
GET /api/workflow/files — list files in workflows/
|
||||
POST /api/workflow/upload — upload a workflow JSON
|
||||
POST /api/workflow/load — load a workflow from workflows/
|
||||
GET /api/workflow/models?type=checkpoints|loras — available models (60s TTL cache)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_comfy, get_config, get_inspector, get_user_registry
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
_WORKFLOWS_DIR = _PROJECT_ROOT / "workflows"
|
||||
|
||||
# Simple in-memory TTL cache for models
|
||||
_models_cache: dict = {}
|
||||
_MODELS_TTL = 60.0
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_workflow(user: dict = Depends(require_auth)):
|
||||
"""Return basic info about the currently loaded workflow."""
|
||||
user_label: str = user["sub"]
|
||||
registry = get_user_registry()
|
||||
|
||||
if registry:
|
||||
template = registry.get_workflow_template(user_label)
|
||||
last_wf = registry.get_state_manager(user_label).get_last_workflow_file()
|
||||
else:
|
||||
# Fallback to global state when registry is unavailable
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "Workflow manager not available")
|
||||
template = comfy.workflow_manager.get_workflow_template()
|
||||
last_wf = comfy.state_manager.get_last_workflow_file()
|
||||
|
||||
return {
|
||||
"loaded": template is not None,
|
||||
"node_count": len(template) if template else 0,
|
||||
"last_workflow_file": last_wf,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/inputs")
|
||||
async def get_workflow_inputs(user: dict = Depends(require_auth)):
|
||||
"""Return dynamic NodeInput list (common + advanced) for the current workflow."""
|
||||
user_label: str = user["sub"]
|
||||
inspector = get_inspector()
|
||||
if inspector is None:
|
||||
raise HTTPException(503, "Workflow components not available")
|
||||
|
||||
registry = get_user_registry()
|
||||
if registry:
|
||||
template = registry.get_workflow_template(user_label)
|
||||
overrides = registry.get_state_manager(user_label).get_overrides()
|
||||
else:
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "Workflow components not available")
|
||||
template = comfy.workflow_manager.get_workflow_template()
|
||||
overrides = comfy.state_manager.get_overrides()
|
||||
|
||||
if template is None:
|
||||
return {"common": [], "advanced": []}
|
||||
|
||||
inputs = inspector.inspect(template)
|
||||
result = []
|
||||
for ni in inputs:
|
||||
val = overrides.get(ni.key, ni.current_value)
|
||||
result.append({
|
||||
"key": ni.key,
|
||||
"label": ni.label,
|
||||
"input_type": ni.input_type,
|
||||
"current_value": val,
|
||||
"node_class": ni.node_class,
|
||||
"node_title": ni.node_title,
|
||||
"is_common": ni.is_common,
|
||||
})
|
||||
common = [r for r in result if r["is_common"]]
|
||||
advanced = [r for r in result if not r["is_common"]]
|
||||
return {"common": common, "advanced": advanced}
|
||||
|
||||
|
||||
@router.get("/files")
|
||||
async def list_workflow_files(_: dict = Depends(require_auth)):
|
||||
"""List .json files in the workflows/ folder."""
|
||||
_WORKFLOWS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
files = sorted(p.name for p in _WORKFLOWS_DIR.glob("*.json"))
|
||||
return {"files": files}
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_workflow(
|
||||
file: UploadFile = File(...),
|
||||
_: dict = Depends(require_auth),
|
||||
):
|
||||
"""Upload a workflow JSON to the workflows/ folder."""
|
||||
_WORKFLOWS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
filename = file.filename or "workflow.json"
|
||||
if not filename.endswith(".json"):
|
||||
filename += ".json"
|
||||
data = await file.read()
|
||||
try:
|
||||
json.loads(data) # validate JSON
|
||||
except json.JSONDecodeError as exc:
|
||||
raise HTTPException(400, f"Invalid JSON: {exc}")
|
||||
dest = _WORKFLOWS_DIR / filename
|
||||
dest.write_bytes(data)
|
||||
return {"ok": True, "filename": filename}
|
||||
|
||||
|
||||
@router.post("/load")
|
||||
async def load_workflow(filename: str = Form(...), user: dict = Depends(require_auth)):
|
||||
"""Load a workflow from the workflows/ folder into the user's isolated state."""
|
||||
wf_path = _WORKFLOWS_DIR / filename
|
||||
if not wf_path.exists():
|
||||
raise HTTPException(404, f"Workflow file '{filename}' not found")
|
||||
try:
|
||||
with open(wf_path, "r", encoding="utf-8") as f:
|
||||
workflow = json.load(f)
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, str(exc))
|
||||
|
||||
registry = get_user_registry()
|
||||
if registry:
|
||||
registry.set_workflow(user["sub"], workflow, filename)
|
||||
else:
|
||||
# Fallback: update global state when registry unavailable
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "ComfyUI not available")
|
||||
comfy.workflow_manager.set_workflow_template(workflow)
|
||||
comfy.state_manager.clear_overrides()
|
||||
comfy.state_manager.set_last_workflow_file(filename)
|
||||
|
||||
inspector = get_inspector()
|
||||
node_count = len(workflow)
|
||||
inputs_count = len(inspector.inspect(workflow)) if inspector else 0
|
||||
return {
|
||||
"ok": True,
|
||||
"filename": filename,
|
||||
"node_count": node_count,
|
||||
"inputs_count": inputs_count,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def get_models(type: str = "checkpoints", _: dict = Depends(require_auth)):
|
||||
"""Return available model names from ComfyUI (60s TTL cache)."""
|
||||
global _models_cache
|
||||
now = time.time()
|
||||
cache_key = type
|
||||
cached = _models_cache.get(cache_key)
|
||||
if cached and (now - cached["ts"]) < _MODELS_TTL:
|
||||
return {"type": type, "models": cached["models"]}
|
||||
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "ComfyUI not available")
|
||||
models = await comfy.get_models(type)
|
||||
_models_cache[cache_key] = {"models": models, "ts": now}
|
||||
return {"type": type, "models": models}
|
||||
61
web/routers/ws_router.py
Normal file
61
web/routers/ws_router.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""WebSocket /ws?token=<jwt> — real-time event stream"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from web.auth import verify_ws_token
|
||||
from web.ws_bus import get_bus
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket, token: str = ""):
|
||||
"""
|
||||
Authenticate via JWT query param or ttb_session cookie, then stream events from WSBus.
|
||||
|
||||
Events common to all users: status_snapshot, queue_update, node_executing, server_state
|
||||
Events private to submitter: generation_complete, generation_error
|
||||
"""
|
||||
payload = verify_ws_token(token)
|
||||
if payload is None:
|
||||
# Fallback: browsers send cookies automatically with WebSocket connections
|
||||
cookie_token = websocket.cookies.get("ttb_session", "")
|
||||
payload = verify_ws_token(cookie_token)
|
||||
if payload is None:
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
user_label: str = payload.get("sub", "anonymous")
|
||||
bus = get_bus()
|
||||
queue = bus.subscribe(user_label)
|
||||
await websocket.accept()
|
||||
logger.info("WS connected: user=%s", user_label)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Wait for an event from the bus
|
||||
try:
|
||||
frame = await asyncio.wait_for(queue.get(), timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
# Send a keepalive ping
|
||||
try:
|
||||
await websocket.send_text('{"type":"ping"}')
|
||||
except Exception:
|
||||
break
|
||||
continue
|
||||
|
||||
try:
|
||||
await websocket.send_text(frame)
|
||||
except Exception:
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
bus.unsubscribe(user_label, queue)
|
||||
logger.info("WS disconnected: user=%s", user_label)
|
||||
122
web/ws_bus.py
Normal file
122
web/ws_bus.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
web/ws_bus.py
|
||||
=============
|
||||
|
||||
In-process WebSocket event bus.
|
||||
|
||||
All connected web clients share a single WSBus instance. Events are
|
||||
delivered per-user (private results) or to all users (shared status).
|
||||
|
||||
Usage::
|
||||
|
||||
bus = WSBus()
|
||||
|
||||
# Subscribe (returns a queue; caller reads from it)
|
||||
q = bus.subscribe("alice")
|
||||
|
||||
# Broadcast to all
|
||||
await bus.broadcast("status_snapshot", {...})
|
||||
|
||||
# Broadcast to one user (all their open tabs)
|
||||
await bus.broadcast_to_user("alice", "generation_complete", {...})
|
||||
|
||||
# Unsubscribe when WS disconnects
|
||||
bus.unsubscribe("alice", q)
|
||||
|
||||
Event frame format sent on wire:
|
||||
{"type": "event_name", "data": {...}, "ts": 1234567890.123}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WSBus:
|
||||
"""
|
||||
Per-user broadcast bus backed by asyncio queues.
|
||||
|
||||
Thread-safe as long as all callers run in the same event loop.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# user_label → set of asyncio.Queue
|
||||
self._clients: Dict[str, Set[asyncio.Queue]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Subscription lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def subscribe(self, user_label: str) -> asyncio.Queue:
|
||||
"""Register a new client connection. Returns the queue to read from."""
|
||||
q: asyncio.Queue = asyncio.Queue(maxsize=256)
|
||||
self._clients.setdefault(user_label, set()).add(q)
|
||||
logger.debug("WSBus: %s subscribed (%d queues)", user_label,
|
||||
len(self._clients[user_label]))
|
||||
return q
|
||||
|
||||
def unsubscribe(self, user_label: str, queue: asyncio.Queue) -> None:
|
||||
"""Remove a client connection."""
|
||||
queues = self._clients.get(user_label, set())
|
||||
queues.discard(queue)
|
||||
if not queues:
|
||||
self._clients.pop(user_label, None)
|
||||
logger.debug("WSBus: %s unsubscribed", user_label)
|
||||
|
||||
@property
|
||||
def connected_users(self) -> list[str]:
|
||||
"""List of user labels with at least one active connection."""
|
||||
return list(self._clients.keys())
|
||||
|
||||
@property
|
||||
def total_connections(self) -> int:
|
||||
return sum(len(qs) for qs in self._clients.values())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Broadcasting
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _frame(self, event_type: str, data: Any) -> str:
|
||||
return json.dumps({"type": event_type, "data": data, "ts": time.time()})
|
||||
|
||||
async def broadcast(self, event_type: str, data: Any) -> None:
|
||||
"""Send an event to ALL connected clients."""
|
||||
frame = self._frame(event_type, data)
|
||||
for queues in list(self._clients.values()):
|
||||
for q in list(queues):
|
||||
try:
|
||||
q.put_nowait(frame)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("WSBus: queue full, dropping %s event", event_type)
|
||||
|
||||
async def broadcast_to_user(
|
||||
self, user_label: str, event_type: str, data: Any
|
||||
) -> None:
|
||||
"""Send an event to all connections belonging to *user_label*."""
|
||||
queues = self._clients.get(user_label, set())
|
||||
if not queues:
|
||||
logger.debug("WSBus: no clients for user '%s', dropping %s", user_label, event_type)
|
||||
return
|
||||
frame = self._frame(event_type, data)
|
||||
for q in list(queues):
|
||||
try:
|
||||
q.put_nowait(frame)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("WSBus: queue full for %s, dropping %s", user_label, event_type)
|
||||
|
||||
|
||||
# Module-level singleton (set by web/app.py)
|
||||
_bus: WSBus | None = None
|
||||
|
||||
|
||||
def get_bus() -> WSBus:
|
||||
global _bus
|
||||
if _bus is None:
|
||||
_bus = WSBus()
|
||||
return _bus
|
||||
Reference in New Issue
Block a user