Files
comfy-discord-web/web/app.py
Khoa (Revenovich) Tran Gia e83703aff0 fix: add ASGI middleware to prevent WS reaching StaticFiles
_WSInterceptMiddleware intercepts /ws WebSocket scopes at the outermost
ASGI layer before Starlette routing is consulted, so Mount('/') can never
hand a WS connection to _SPAStaticFiles regardless of route ordering.

Also downgrade the _SPAStaticFiles non-HTTP fallback log from WARNING to
DEBUG — the graceful close still fires as a safety net, but no longer
spams the log since the middleware handles the normal case.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-03 11:19:30 +07:00

319 lines
12 KiB
Python

"""
web/app.py
==========
FastAPI application factory.
The app is created once and shared between the Uvicorn server (started
from bot.py via asyncio.gather) and tests.
Startup tasks:
- Background status ticker (broadcasts status_snapshot every 5s to all clients)
- Background NSSM poll (broadcasts server_state every 10s to all clients)
"""
from __future__ import annotations
import asyncio
import logging
import mimetypes
import os
from pathlib import Path
from fastapi import FastAPI
from starlette.exceptions import HTTPException as _HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request as _Request
# The Windows registry can associate .js with text/plain; register the
# correct MIME types up front so static assets are served with the right
# Content-Type.
for _mime_type, _extension in (
    ("application/javascript", ".js"),
    ("application/javascript", ".mjs"),
    ("text/css", ".css"),
    ("application/wasm", ".wasm"),
):
    mimetypes.add_type(_mime_type, _extension)
del _mime_type, _extension
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
class _NoCacheHTMLMiddleware(BaseHTTPMiddleware):
    """Make browsers revalidate HTML responses on every request.

    Vite puts a content hash in JS/CSS filenames, so those assets are
    cache-busted by each build. index.html keeps a stable name; without an
    explicit Cache-Control header, mobile browsers fall back to heuristic
    caching and may keep showing a stale copy after a redeploy.
    """

    async def dispatch(self, request: _Request, call_next):
        """Run the downstream app, then tag ``text/html`` responses as no-cache."""
        response = await call_next(request)
        content_type = response.headers.get("content-type", "")
        if "text/html" not in content_type:
            # Non-HTML responses pass through untouched.
            return response
        response.headers["Cache-Control"] = "no-cache, must-revalidate"
        return response
class _SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a fixed set of security headers to each response."""

    # Content-Security-Policy value, assembled from its individual directives
    # (each directive is terminated by ';', directives separated by a space).
    _CSP = "; ".join(
        (
            "default-src 'self'",
            "script-src 'self' 'unsafe-inline'",
            "style-src 'self' 'unsafe-inline'",
            "img-src 'self' data: blob:",
            "connect-src 'self' wss:",
            "frame-ancestors 'none'",
        )
    ) + ";"

    async def dispatch(self, request: _Request, call_next):
        """Run the downstream app and overwrite the security headers on its response."""
        response = await call_next(request)
        security_headers = {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            "Content-Security-Policy": self._CSP,
        }
        for header_name, header_value in security_headers.items():
            response.headers[header_name] = header_value
        return response
class _WSInterceptMiddleware:
    """Short-circuit /ws WebSocket connections before Starlette routing.

    In Starlette 0.52+, Mount('/') returns Match.FULL for every WebSocket
    scope, so a WS connection can reach _SPAStaticFiles if the dedicated /ws
    WebSocketRoute is somehow not matched first. Wrapping the whole ASGI app
    here guarantees /ws is handled before any route or mount is consulted.

    Args:
        app: The wrapped ASGI application; receives every scope that is not a
            WebSocket connection to exactly ``/ws``.
        ws_handler: Handler for ``/ws`` WebSocket scopes. Either a raw ASGI
            callable ``(scope, receive, send)`` or a session-style endpoint
            taking a single ``WebSocket`` argument (the form Starlette's
            ``add_websocket_route`` expects for plain functions).
    """

    def __init__(self, app, ws_handler) -> None:
        self._app = app
        self._ws = self._as_asgi(ws_handler)

    @staticmethod
    def _as_asgi(handler):
        """Return *handler* as an ASGI callable.

        A handler with exactly one required positional parameter cannot be
        called with ``(scope, receive, send)``; it is treated as a
        session-style endpoint and wrapped so it receives a
        ``starlette.websockets.WebSocket``. Anything else is returned
        unchanged and called as a raw ASGI app.
        """
        import inspect

        try:
            parameters = inspect.signature(handler).parameters.values()
        except (TypeError, ValueError):
            # Signature not introspectable: assume it is already ASGI-style.
            return handler

        required_positional = [
            p
            for p in parameters
            if p.default is p.empty
            and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        ]
        if len(required_positional) != 1:
            return handler

        async def _session_app(scope, receive, send) -> None:
            # Imported lazily so raw ASGI handlers never need this import.
            from starlette.websockets import WebSocket

            await handler(WebSocket(scope, receive=receive, send=send))

        return _session_app

    async def __call__(self, scope, receive, send) -> None:
        """Dispatch ``/ws`` WebSocket scopes to the handler, everything else to the app."""
        if scope["type"] == "websocket" and scope.get("path") == "/ws":
            await self._ws(scope, receive, send)
        else:
            await self._app(scope, receive, send)
class _SPAStaticFiles(StaticFiles):
    """Static file app that falls back to index.html for unknown paths.

    Starlette's html=True only serves index.html for directory requests.
    This subclass also answers with index.html whenever the requested path
    has no matching file, so client-side routes such as /generate keep
    working when the page is refreshed.
    """

    async def __call__(self, scope, receive, send) -> None:
        """Serve HTTP scopes normally; close any other scope gracefully."""
        if scope.get("type") == "http":
            await super().__call__(scope, receive, send)
            return

        # StaticFiles only understands HTTP. In Starlette 0.52+, Mount('/')
        # returns Match.FULL for ALL WebSocket scopes. _WSInterceptMiddleware
        # is expected to keep WS connections away from here; this branch is a
        # last-resort safety net that closes the connection instead of
        # letting StaticFiles raise AssertionError.
        from starlette.websockets import WebSocketClose

        logger.debug(
            "non-HTTP scope reached StaticFiles (path=%r, type=%r) — closing gracefully",
            scope.get("path"),
            scope.get("type"),
        )
        close_app = WebSocketClose()
        await close_app(scope, receive, send)

    async def get_response(self, path: str, scope):
        """Return the file at *path*, or index.html when the lookup yields a 404."""
        try:
            return await super().get_response(path, scope)
        except _HTTPException as ex:
            if ex.status_code != 404:
                # Anything other than "not found" is propagated unchanged.
                raise
            return await super().get_response("index.html", scope)
from web.ws_bus import get_bus
# Module-level logger shared by the middleware classes and background tasks.
logger = logging.getLogger(__name__)
# Project root is two directory levels above this file.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Location of the built frontend bundle, directly under the project root.
_WEB_STATIC = Path(_PROJECT_ROOT, "web-static")
def create_app() -> FastAPI:
    """Create and configure the FastAPI application.

    Sets up, in order: optional CORS (origins read from the comma-separated
    ``CORS_ORIGINS`` environment variable), the security-header and no-cache
    middleware, all API routers, the ``/ws`` WebSocket route, the SPA static
    mount (only when the ``web-static`` directory exists and is non-empty),
    the ``/ws`` intercept middleware, and the startup/shutdown hooks that
    manage the background broadcast tasks.

    Returns:
        The configured application. The interactive docs, ReDoc and OpenAPI
        schema endpoints are disabled.
    """
    app = FastAPI(
        title="ComfyUI Bot Web UI",
        version="1.0.0",
        docs_url=None,
        redoc_url=None,
        openapi_url=None,
    )

    # CORS — only allow explicitly configured origins; empty = no cross-origin
    _cors_origins = [o.strip() for o in os.getenv("CORS_ORIGINS", "").split(",") if o.strip()]
    if _cors_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=_cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # Security headers on every response
    app.add_middleware(_SecurityHeadersMiddleware)

    # Prevent browsers from caching index.html across deploys
    app.add_middleware(_NoCacheHTMLMiddleware)

    # Register API routers
    from web.routers.auth_router import router as auth_router
    from web.routers.admin_router import router as admin_router
    from web.routers.status_router import router as status_router
    from web.routers.state_router import router as state_router
    from web.routers.generate_router import router as generate_router
    from web.routers.inputs_router import router as inputs_router
    from web.routers.presets_router import router as presets_router
    from web.routers.server_router import router as server_router
    from web.routers.history_router import router as history_router
    from web.routers.share_router import router as share_router
    from web.routers.workflow_router import router as workflow_router
    from web.routers.ws_router import router as ws_router

    app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
    app.include_router(admin_router, prefix="/api/admin", tags=["admin"])
    app.include_router(status_router, prefix="/api", tags=["status"])
    app.include_router(state_router, prefix="/api", tags=["state"])
    app.include_router(generate_router, prefix="/api", tags=["generate"])
    app.include_router(inputs_router, prefix="/api/inputs", tags=["inputs"])
    app.include_router(presets_router, prefix="/api/presets", tags=["presets"])
    app.include_router(server_router, prefix="/api", tags=["server"])
    app.include_router(history_router, prefix="/api/history", tags=["history"])
    app.include_router(share_router, prefix="/api/share", tags=["share"])
    app.include_router(workflow_router, prefix="/api/workflow", tags=["workflow"])
    app.include_router(ws_router, tags=["ws"])

    from web.routers.ws_router import websocket_endpoint as _ws_endpoint

    # Belt-and-suspenders: also register /ws directly on the app. Routes are
    # matched in registration order, so adding it here — before the
    # StaticFiles catch-all mount below — ensures it is checked first.
    app.add_websocket_route("/ws", _ws_endpoint)

    # Serve frontend static files (if built)
    if _WEB_STATIC.exists() and any(_WEB_STATIC.iterdir()):
        app.mount("/", _SPAStaticFiles(directory=str(_WEB_STATIC), html=True), name="static")
        logger.info("Serving frontend from %s", _WEB_STATIC)

    # Wrap the entire ASGI app so /ws WebSocket connections are intercepted
    # before Starlette routing can hand them to the StaticFiles Mount('/').
    # The middleware added last is the outermost user middleware, so it runs
    # before any route or mount is consulted.
    app.add_middleware(_WSInterceptMiddleware, ws_handler=_ws_endpoint)

    @app.on_event("startup")
    async def _startup():
        # The event loop keeps only weak references to tasks. Hold strong
        # references on app.state so the tickers cannot be garbage-collected
        # mid-run and can be cancelled on shutdown.
        app.state.background_tasks = [
            asyncio.create_task(_status_ticker(), name="web-status-ticker"),
            asyncio.create_task(_server_state_poller(), name="web-server-state-poller"),
        ]
        logger.info("Web background tasks started")

    @app.on_event("shutdown")
    async def _shutdown():
        # Cancel the infinite-loop tasks and wait for them to unwind so they
        # are not left pending when the loop stops.
        tasks = getattr(app.state, "background_tasks", [])
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        app.state.background_tasks = []

    return app
# ---------------------------------------------------------------------------
# Background tasks
# ---------------------------------------------------------------------------
def _latency_ms(latency) -> int:
    """Convert a latency in seconds to whole milliseconds; 0 when unusable.

    ``None`` and non-finite values (``inf``, ``-inf``, ``nan``) map to 0 so
    that ``round()`` is never asked to convert them to an integer.
    """
    import math

    if latency is None or not math.isfinite(latency):
        return 0
    return round(latency * 1000)


def _format_uptime(total_seconds: int) -> str:
    """Format a duration as ``"Hh Mm Ss"``, omitting leading zero units."""
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    if hours:
        return f"{hours}h {minutes}m {seconds}s"
    if minutes:
        return f"{minutes}m {seconds}s"
    return f"{seconds}s"


async def _status_ticker() -> None:
    """Broadcast status_snapshot to all clients every 5 seconds.

    Each tick assembles a snapshot with up to three sections — ``bot``,
    ``comfy`` and ``upload`` — depending on which dependencies are available.
    When users are connected and a user registry exists, each user receives
    the snapshot plus their own ``overrides``; otherwise the bare snapshot is
    broadcast to everyone. Errors are logged at DEBUG level and never stop
    the loop; the loop ends only when the task is cancelled.
    """
    import datetime as _dt

    from web.deps import get_bot, get_comfy, get_config

    bus = get_bus()
    while True:
        await asyncio.sleep(5)
        try:
            bot = get_bot()
            comfy = get_comfy()
            config = get_config()
            snapshot: dict = {}

            if bot is not None:
                # The latency may be non-finite (inf/nan) before a heartbeat
                # has completed; _latency_ms maps those to 0 rather than
                # letting round() raise and drop the whole tick.
                lat_ms = _latency_ms(bot.latency)
                start = getattr(bot, "start_time", None)
                uptime = ""
                if start:
                    # NOTE(review): assumes start_time is timezone-aware (UTC) — confirm.
                    elapsed = _dt.datetime.now(_dt.timezone.utc) - start
                    uptime = _format_uptime(int(elapsed.total_seconds()))
                snapshot["bot"] = {"latency_ms": lat_ms, "uptime": uptime}

            if comfy is not None:
                q = await comfy.get_comfy_queue()
                pending = len(q.get("queue_pending", [])) if q else 0
                running = len(q.get("queue_running", [])) if q else 0
                wm = getattr(comfy, "workflow_manager", None)
                wf_loaded = wm is not None and wm.get_workflow_template() is not None
                snapshot["comfy"] = {
                    "server": comfy.server_address,
                    "queue_pending": pending,
                    "queue_running": running,
                    "workflow_loaded": wf_loaded,
                    "last_seed": comfy.last_seed,
                    "total_generated": comfy.total_generated,
                }

            if config is not None:
                # Upload stats are best-effort: a failure here must not
                # prevent the rest of the snapshot from being broadcast.
                try:
                    from media_uploader import get_stats as get_upload_stats, is_running as upload_running

                    us = get_upload_stats()
                    snapshot["upload"] = {
                        "configured": bool(config.media_upload_user),
                        "running": upload_running(),
                        "total_ok": us.total_ok,
                        "total_fail": us.total_fail,
                    }
                except Exception as exc:
                    logger.debug("Upload stats unavailable: %s", exc)

            from web.deps import get_user_registry

            registry = get_user_registry()
            connected = bus.connected_users
            if connected and registry:
                # Iterate over a snapshot: awaiting inside the loop gives
                # clients a chance to connect/disconnect, which could mutate
                # the underlying collection mid-iteration.
                for ul in tuple(connected):
                    try:
                        user_overrides = registry.get_state_manager(ul).get_overrides()
                        await bus.broadcast_to_user(
                            ul, "status_snapshot", {**snapshot, "overrides": user_overrides}
                        )
                    except Exception as exc:
                        # One failing user must not starve the remaining ones.
                        logger.debug("Status ticker error for user %r: %s", ul, exc)
            else:
                await bus.broadcast("status_snapshot", snapshot)
        except Exception as exc:
            logger.debug("Status ticker error: %s", exc)
async def _server_state_poller() -> None:
    """Poll NSSM service state and broadcast server_state every 10 seconds.

    Each cycle queries the service state and, concurrently, whether the
    ComfyUI client reports a working connection (``False`` when no client is
    available), then broadcasts both as a ``server_state`` event. Cycles are
    skipped while no config is available; errors are logged at DEBUG level
    and the loop continues.
    """
    from web.deps import get_config

    async def _false():
        # Stand-in awaitable used when there is no ComfyUI client to probe.
        return False

    event_bus = get_bus()
    while True:
        await asyncio.sleep(10)
        try:
            cfg = get_config()
            if cfg is None:
                continue

            from commands.server import get_service_state
            from web.deps import get_comfy

            client = get_comfy()
            # Build both awaitables first (service query, then reachability
            # probe) and run them concurrently.
            state_query = get_service_state(cfg.comfy_service_name)
            if client:
                reachability_probe = client.check_connection()
            else:
                reachability_probe = _false()
            service_state, http_reachable = await asyncio.gather(
                state_query,
                reachability_probe,
            )

            payload = {
                "state": service_state,
                "http_reachable": http_reachable,
            }
            await event_bus.broadcast("server_state", payload)
        except Exception as exc:
            logger.debug("Server state poller error: %s", exc)