Initial commit — ComfyUI Discord bot + web UI
Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI (React/TS/Vite/Tailwind), ComfyUI integration, generation history DB, preset manager, workflow inspector, and all supporting modules. Excluded from tracking: .env, invite_tokens.json, *.db (SQLite), current-workflow-changes.json, user_settings/, presets/, logs/, web-static/ (build output), frontend/node_modules/. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
269
web/app.py
Normal file
269
web/app.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
web/app.py
|
||||
==========
|
||||
|
||||
FastAPI application factory.
|
||||
|
||||
The app is created once and shared between the Uvicorn server (started
|
||||
from bot.py via asyncio.gather) and tests.
|
||||
|
||||
Startup tasks:
|
||||
- Background status ticker (broadcasts status_snapshot every 5s to all clients)
|
||||
- Background NSSM poll (broadcasts server_state every 10s to all clients)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
from starlette.exceptions import HTTPException as _HTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request as _Request
|
||||
|
||||
# Windows registry can map .js → text/plain; override to the correct types
|
||||
# before StaticFiles reads them.
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("application/javascript", ".mjs")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
mimetypes.add_type("application/wasm", ".wasm")
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
|
||||
class _NoCacheHTMLMiddleware(BaseHTTPMiddleware):
|
||||
"""Force browsers to revalidate index.html on every request.
|
||||
|
||||
Vite hashes JS/CSS filenames on every build so those assets are
|
||||
naturally cache-busted. index.html itself has a stable name, so
|
||||
without an explicit Cache-Control header mobile browsers apply
|
||||
heuristic caching and keep serving a stale copy after a redeploy.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: _Request, call_next):
|
||||
response = await call_next(request)
|
||||
ct = response.headers.get("content-type", "")
|
||||
if "text/html" in ct:
|
||||
response.headers["Cache-Control"] = "no-cache, must-revalidate"
|
||||
return response
|
||||
|
||||
|
||||
class _SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Add security headers to every response."""
|
||||
|
||||
_CSP = (
|
||||
"default-src 'self'; "
|
||||
"script-src 'self' 'unsafe-inline'; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
"img-src 'self' data: blob:; "
|
||||
"connect-src 'self' wss:; "
|
||||
"frame-ancestors 'none';"
|
||||
)
|
||||
|
||||
async def dispatch(self, request: _Request, call_next):
|
||||
response = await call_next(request)
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Content-Security-Policy"] = self._CSP
|
||||
return response
|
||||
|
||||
class _SPAStaticFiles(StaticFiles):
|
||||
"""StaticFiles with SPA fallback: serve index.html for unknown paths.
|
||||
|
||||
Starlette's html=True only serves index.html for directory requests.
|
||||
This subclass additionally returns index.html for any path that has no
|
||||
matching file, so client-side routes like /generate work on refresh.
|
||||
"""
|
||||
|
||||
async def get_response(self, path: str, scope):
|
||||
try:
|
||||
return await super().get_response(path, scope)
|
||||
except _HTTPException as ex:
|
||||
if ex.status_code == 404:
|
||||
return await super().get_response("index.html", scope)
|
||||
raise
|
||||
|
||||
|
||||
from web.ws_bus import get_bus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
_WEB_STATIC = _PROJECT_ROOT / "web-static"
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create and configure the FastAPI application."""
|
||||
app = FastAPI(
|
||||
title="ComfyUI Bot Web UI",
|
||||
version="1.0.0",
|
||||
docs_url=None,
|
||||
redoc_url=None,
|
||||
openapi_url=None,
|
||||
)
|
||||
|
||||
# CORS — only allow explicitly configured origins; empty = no cross-origin
|
||||
_cors_origins = [o.strip() for o in os.getenv("CORS_ORIGINS", "").split(",") if o.strip()]
|
||||
if _cors_origins:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=_cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Security headers on every response
|
||||
app.add_middleware(_SecurityHeadersMiddleware)
|
||||
|
||||
# Prevent browsers from caching index.html across deploys
|
||||
app.add_middleware(_NoCacheHTMLMiddleware)
|
||||
|
||||
# Register API routers
|
||||
from web.routers.auth_router import router as auth_router
|
||||
from web.routers.admin_router import router as admin_router
|
||||
from web.routers.status_router import router as status_router
|
||||
from web.routers.state_router import router as state_router
|
||||
from web.routers.generate_router import router as generate_router
|
||||
from web.routers.inputs_router import router as inputs_router
|
||||
from web.routers.presets_router import router as presets_router
|
||||
from web.routers.server_router import router as server_router
|
||||
from web.routers.history_router import router as history_router
|
||||
from web.routers.share_router import router as share_router
|
||||
from web.routers.workflow_router import router as workflow_router
|
||||
from web.routers.ws_router import router as ws_router
|
||||
|
||||
app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
|
||||
app.include_router(admin_router, prefix="/api/admin", tags=["admin"])
|
||||
app.include_router(status_router, prefix="/api", tags=["status"])
|
||||
app.include_router(state_router, prefix="/api", tags=["state"])
|
||||
app.include_router(generate_router, prefix="/api", tags=["generate"])
|
||||
app.include_router(inputs_router, prefix="/api/inputs", tags=["inputs"])
|
||||
app.include_router(presets_router, prefix="/api/presets", tags=["presets"])
|
||||
app.include_router(server_router, prefix="/api", tags=["server"])
|
||||
app.include_router(history_router, prefix="/api/history", tags=["history"])
|
||||
app.include_router(share_router, prefix="/api/share", tags=["share"])
|
||||
app.include_router(workflow_router, prefix="/api/workflow", tags=["workflow"])
|
||||
app.include_router(ws_router, tags=["ws"])
|
||||
|
||||
# Serve frontend static files (if built)
|
||||
if _WEB_STATIC.exists() and any(_WEB_STATIC.iterdir()):
|
||||
app.mount("/", _SPAStaticFiles(directory=str(_WEB_STATIC), html=True), name="static")
|
||||
logger.info("Serving frontend from %s", _WEB_STATIC)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def _startup():
|
||||
asyncio.create_task(_status_ticker())
|
||||
asyncio.create_task(_server_state_poller())
|
||||
logger.info("Web background tasks started")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Background tasks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _status_ticker() -> None:
|
||||
"""Broadcast status_snapshot to all clients every 5 seconds."""
|
||||
from web.deps import get_bot, get_comfy, get_config
|
||||
bus = get_bus()
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
try:
|
||||
bot = get_bot()
|
||||
comfy = get_comfy()
|
||||
config = get_config()
|
||||
|
||||
snapshot: dict = {}
|
||||
|
||||
if bot is not None:
|
||||
lat = bot.latency
|
||||
lat_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0
|
||||
import datetime as _dt
|
||||
start = getattr(bot, "start_time", None)
|
||||
uptime = ""
|
||||
if start:
|
||||
delta = _dt.datetime.now(_dt.timezone.utc) - start
|
||||
total = int(delta.total_seconds())
|
||||
h, rem = divmod(total, 3600)
|
||||
m, s = divmod(rem, 60)
|
||||
uptime = f"{h}h {m}m {s}s" if h else (f"{m}m {s}s" if m else f"{s}s")
|
||||
snapshot["bot"] = {"latency_ms": lat_ms, "uptime": uptime}
|
||||
|
||||
if comfy is not None:
|
||||
q = await comfy.get_comfy_queue()
|
||||
pending = len(q.get("queue_pending", [])) if q else 0
|
||||
running = len(q.get("queue_running", [])) if q else 0
|
||||
wm = getattr(comfy, "workflow_manager", None)
|
||||
wf_loaded = wm is not None and wm.get_workflow_template() is not None
|
||||
snapshot["comfy"] = {
|
||||
"server": comfy.server_address,
|
||||
"queue_pending": pending,
|
||||
"queue_running": running,
|
||||
"workflow_loaded": wf_loaded,
|
||||
"last_seed": comfy.last_seed,
|
||||
"total_generated": comfy.total_generated,
|
||||
}
|
||||
|
||||
if config is not None:
|
||||
from media_uploader import get_stats as get_upload_stats, is_running as upload_running
|
||||
try:
|
||||
us = get_upload_stats()
|
||||
snapshot["upload"] = {
|
||||
"configured": bool(config.media_upload_user),
|
||||
"running": upload_running(),
|
||||
"total_ok": us.total_ok,
|
||||
"total_fail": us.total_fail,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from web.deps import get_user_registry
|
||||
registry = get_user_registry()
|
||||
connected = bus.connected_users
|
||||
if connected and registry:
|
||||
for ul in connected:
|
||||
user_overrides = registry.get_state_manager(ul).get_overrides()
|
||||
await bus.broadcast_to_user(ul, "status_snapshot", {**snapshot, "overrides": user_overrides})
|
||||
else:
|
||||
await bus.broadcast("status_snapshot", snapshot)
|
||||
except Exception as exc:
|
||||
logger.debug("Status ticker error: %s", exc)
|
||||
|
||||
|
||||
async def _server_state_poller() -> None:
|
||||
"""Poll NSSM service state and broadcast server_state every 10 seconds."""
|
||||
from web.deps import get_config
|
||||
bus = get_bus()
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
try:
|
||||
config = get_config()
|
||||
if config is None:
|
||||
continue
|
||||
from commands.server import get_service_state
|
||||
from web.deps import get_comfy
|
||||
|
||||
async def _false():
|
||||
return False
|
||||
|
||||
comfy = get_comfy()
|
||||
service_state, http_reachable = await asyncio.gather(
|
||||
get_service_state(config.comfy_service_name),
|
||||
comfy.check_connection() if comfy else _false(),
|
||||
)
|
||||
await bus.broadcast("server_state", {
|
||||
"state": service_state,
|
||||
"http_reachable": http_reachable,
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.debug("Server state poller error: %s", exc)
|
||||
Reference in New Issue
Block a user