In Starlette 0.52+, Mount('/') returns Match.FULL for every WebSocket
scope. If APIWebSocketRoute('/ws') is somehow not matched first, the
StaticFiles mount catches the connection and crashes with:
assert scope["type"] == "http" # AssertionError
Two-layer fix:
- _SPAStaticFiles.__call__: gracefully close non-HTTP connections with
WebSocketClose() and log a warning with the actual path/type so the
routing issue can be diagnosed.
- app.add_websocket_route('/ws', websocket_endpoint): belt-and-suspenders
registration using Starlette's base WebSocketRoute (simpler than
FastAPI's APIWebSocketRoute) right before the StaticFiles mount. If
include_router's APIWebSocketRoute doesn't match, this fallback will.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
298 lines
12 KiB
Python
"""
|
|
web/app.py
|
|
==========
|
|
|
|
FastAPI application factory.
|
|
|
|
The app is created once and shared between the Uvicorn server (started
|
|
from bot.py via asyncio.gather) and tests.
|
|
|
|
Startup tasks:
|
|
- Background status ticker (broadcasts status_snapshot every 5s to all clients)
|
|
- Background NSSM poll (broadcasts server_state every 10s to all clients)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import mimetypes
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI
|
|
from starlette.exceptions import HTTPException as _HTTPException
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request as _Request
|
|
|
|
# On Windows the registry can map .js to text/plain; register the correct
# types explicitly so StaticFiles sends the right Content-Type for them.
for _ext, _mime in (
    (".js", "application/javascript"),
    (".mjs", "application/javascript"),
    (".css", "text/css"),
    (".wasm", "application/wasm"),
):
    mimetypes.add_type(_mime, _ext)
del _ext, _mime  # keep the module namespace free of loop temporaries
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
|
|
class _NoCacheHTMLMiddleware(BaseHTTPMiddleware):
    """Make browsers revalidate HTML documents (index.html) on every load.

    Vite emits content-hashed JS/CSS filenames, so those assets bust caches
    on their own. index.html keeps the same name across builds; without an
    explicit Cache-Control header, mobile browsers cache it heuristically
    and keep showing a stale shell after a redeploy.
    """

    async def dispatch(self, request: _Request, call_next):
        resp = await call_next(request)
        content_type = resp.headers.get("content-type", "")
        if "text/html" not in content_type:
            # Non-HTML responses (hashed assets, JSON) keep default caching.
            return resp
        resp.headers["Cache-Control"] = "no-cache, must-revalidate"
        return resp
|
|
|
|
|
|
class _SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Stamp a fixed set of security headers onto every HTTP response.

    Disables MIME sniffing and framing, limits the referrer, enables HSTS,
    and applies a Content-Security-Policy. The CSP allows scripts and styles
    only from the same origin (inline allowed), images from the same origin
    plus data:/blob:, and connections to the same origin plus any wss: endpoint.
    """

    _CSP = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline'; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: blob:; "
        "connect-src 'self' wss:; "
        "frame-ancestors 'none';"
    )

    async def dispatch(self, request: _Request, call_next):
        response = await call_next(request)
        hdrs = response.headers
        for name, value in (
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            ("Referrer-Policy", "strict-origin-when-cross-origin"),
            ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
            ("Content-Security-Policy", self._CSP),
        ):
            hdrs[name] = value
        return response
|
|
|
|
class _SPAStaticFiles(StaticFiles):
    """StaticFiles with SPA fallback: serve index.html for unknown paths.

    Starlette's html=True only serves index.html for directory requests.
    This subclass also returns index.html when a request has no matching
    file, so client-side routes like /generate work on refresh.

    The fallback is deliberately NOT applied to:

    * paths under ``api/``: an unknown API endpoint should be a real 404,
      not a 200 carrying the SPA shell;
    * paths whose last segment contains a dot (e.g. a stale hashed
      ``assets/index-abc123.js`` after a redeploy): answering a script or
      stylesheet request with HTML gives a confusing MIME-type error in the
      browser instead of a clear 404.

    NOTE(review): a client-side route whose last segment contains a dot
    would now 404 on refresh. Confirm the frontend has no such routes.
    """

    # First path segments that never receive the SPA fallback.
    _NO_FALLBACK_PREFIXES = frozenset({"api"})

    async def __call__(self, scope, receive, send) -> None:
        # StaticFiles only handles HTTP and asserts scope["type"] == "http".
        # A Mount('/') also matches websocket scopes, so a WebSocket connection
        # can land here if the dedicated /ws route did not match first. Reject
        # it with a close frame instead of crashing on the assert. Other scope
        # types are left to StaticFiles; only websocket scopes can receive a
        # websocket.close message.
        if scope.get("type") == "websocket":
            from starlette.websockets import WebSocketClose

            logger.warning(
                "WebSocket or non-HTTP scope reached StaticFiles "
                "(path=%r, type=%r) — closing gracefully. "
                "This indicates a routing issue; /ws route did not match.",
                scope.get("path"),
                scope.get("type"),
            )
            await WebSocketClose()(scope, receive, send)
            return
        await super().__call__(scope, receive, send)

    @classmethod
    def _wants_spa_fallback(cls, path: str) -> bool:
        """Return True if a 404 for *path* should be answered with index.html.

        *path* is the relative path that StaticFiles passes to get_response.
        It may use OS-specific separators, so it is split with Path rather
        than on "/".
        """
        parts = Path(path).parts
        if not parts:
            return True
        if parts[0] in cls._NO_FALLBACK_PREFIXES:
            return False
        # A dot in the last segment is treated as a file request (the same
        # heuristic as connect-history-api-fallback).
        return "." not in parts[-1]

    async def get_response(self, path: str, scope):
        """Serve *path*, falling back to index.html for unknown SPA routes.

        Raises:
            starlette HTTPException: any non-404 error, or a 404 for paths
            excluded from the fallback (see class docstring).
        """
        try:
            return await super().get_response(path, scope)
        except _HTTPException as ex:
            if ex.status_code == 404 and self._wants_spa_fallback(path):
                return await super().get_response("index.html", scope)
            raise
|
|
|
|
|
|
from web.ws_bus import get_bus
|
|
|
|
logger = logging.getLogger(__name__)

# Repository root: this module lives at <root>/web/app.py.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Built frontend bundle; create_app mounts it at "/" when present and non-empty.
_WEB_STATIC = _PROJECT_ROOT.joinpath("web-static")
|
|
|
|
|
|
def create_app() -> FastAPI:
    """Create and configure the FastAPI application.

    Wires up, in order:

    * optional CORS (only when CORS_ORIGINS is set);
    * the security-header and no-cache-HTML middleware;
    * every API router;
    * a fallback ``/ws`` WebSocket route;
    * the SPA static-file mount at ``/``, when a frontend build exists.

    Startup and shutdown hooks start and stop the background broadcast tasks.

    Returns:
        The configured FastAPI instance.
    """
    app = FastAPI(
        title="ComfyUI Bot Web UI",
        version="1.0.0",
        docs_url=None,
        redoc_url=None,
        openapi_url=None,
    )

    # CORS — only allow explicitly configured origins; empty = no cross-origin
    _cors_origins = [o.strip() for o in os.getenv("CORS_ORIGINS", "").split(",") if o.strip()]
    if _cors_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=_cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # Security headers on every response
    app.add_middleware(_SecurityHeadersMiddleware)

    # Prevent browsers from caching index.html across deploys
    app.add_middleware(_NoCacheHTMLMiddleware)

    # Register API routers
    from web.routers.auth_router import router as auth_router
    from web.routers.admin_router import router as admin_router
    from web.routers.status_router import router as status_router
    from web.routers.state_router import router as state_router
    from web.routers.generate_router import router as generate_router
    from web.routers.inputs_router import router as inputs_router
    from web.routers.presets_router import router as presets_router
    from web.routers.server_router import router as server_router
    from web.routers.history_router import router as history_router
    from web.routers.share_router import router as share_router
    from web.routers.workflow_router import router as workflow_router
    from web.routers.ws_router import router as ws_router

    app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
    app.include_router(admin_router, prefix="/api/admin", tags=["admin"])
    app.include_router(status_router, prefix="/api", tags=["status"])
    app.include_router(state_router, prefix="/api", tags=["state"])
    app.include_router(generate_router, prefix="/api", tags=["generate"])
    app.include_router(inputs_router, prefix="/api/inputs", tags=["inputs"])
    app.include_router(presets_router, prefix="/api/presets", tags=["presets"])
    app.include_router(server_router, prefix="/api", tags=["server"])
    app.include_router(history_router, prefix="/api/history", tags=["history"])
    app.include_router(share_router, prefix="/api/share", tags=["share"])
    app.include_router(workflow_router, prefix="/api/workflow", tags=["workflow"])
    app.include_router(ws_router, tags=["ws"])

    # Belt-and-suspenders /ws registration. The router tries routes in
    # registration order and the first Match.FULL wins. add_websocket_route()
    # APPENDS: this route lands after every included router (not at the top
    # of the list), but, crucially, before the catch-all StaticFiles mount
    # added below. A Mount('/') matches every websocket path, so anything
    # registered after it would never be reached. Registering /ws twice is
    # harmless because the first full match wins.
    # NOTE(review): this creates a plain Starlette WebSocketRoute, which calls
    # the endpoint with only the WebSocket; FastAPI dependency injection
    # (Query/Depends parameters) is not applied. Confirm websocket_endpoint
    # needs no other parameters.
    from web.routers.ws_router import websocket_endpoint as _ws_endpoint
    app.add_websocket_route("/ws", _ws_endpoint)

    # Serve frontend static files (if built). is_dir() rather than exists():
    # iterdir() raises NotADirectoryError if the path is a plain file.
    if _WEB_STATIC.is_dir() and any(_WEB_STATIC.iterdir()):
        app.mount("/", _SPAStaticFiles(directory=str(_WEB_STATIC), html=True), name="static")
        logger.info("Serving frontend from %s", _WEB_STATIC)
    else:
        logger.info("No frontend build at %s; serving API only", _WEB_STATIC)

    # Strong references to the background tasks. The event loop keeps only
    # weak references to tasks, so a task nobody references can be garbage
    # collected mid-execution. Also exposed on app.state for inspection.
    background_tasks: set = set()
    app.state.background_tasks = background_tasks

    @app.on_event("startup")
    async def _startup():
        background_tasks.add(asyncio.create_task(_status_ticker(), name="web-status-ticker"))
        background_tasks.add(asyncio.create_task(_server_state_poller(), name="web-server-state-poller"))
        logger.info("Web background tasks started")

    @app.on_event("shutdown")
    async def _shutdown():
        # Both loops run forever; cancel them so shutdown (or a test client
        # closing) does not leave orphaned tasks behind.
        for task in background_tasks:
            task.cancel()
        await asyncio.gather(*background_tasks, return_exceptions=True)
        background_tasks.clear()
        logger.info("Web background tasks stopped")

    return app
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Background tasks
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _format_uptime(total_seconds: int) -> str:
    """Format seconds as "1h 2m 3s", "2m 3s" or "3s" (leading zero units dropped)."""
    h, rem = divmod(total_seconds, 3600)
    m, s = divmod(rem, 60)
    if h:
        return f"{h}h {m}m {s}s"
    if m:
        return f"{m}m {s}s"
    return f"{s}s"


def _bot_status(bot) -> dict:
    """Build the ``bot`` section: latency in ms (0 if unknown) and uptime text."""
    import datetime as _dt
    import math

    lat = bot.latency
    # round() raises on NaN and inf. The bot's latency is presumably NaN
    # before the first heartbeat and inf when a heartbeat is overdue (if this
    # is a discord.py client — TODO confirm); report 0 for both.
    lat_ms = round(lat * 1000) if (lat is not None and math.isfinite(lat)) else 0
    uptime = ""
    start = getattr(bot, "start_time", None)
    if start:
        # NOTE(review): start_time must be timezone-aware; subtracting a
        # naive datetime from an aware one raises TypeError.
        delta = _dt.datetime.now(_dt.timezone.utc) - start
        uptime = _format_uptime(int(delta.total_seconds()))
    return {"latency_ms": lat_ms, "uptime": uptime}


async def _comfy_status(comfy) -> dict:
    """Build the ``comfy`` section: queue depth plus workflow/seed/generation info."""
    q = await comfy.get_comfy_queue()
    pending = len(q.get("queue_pending", [])) if q else 0
    running = len(q.get("queue_running", [])) if q else 0
    wm = getattr(comfy, "workflow_manager", None)
    wf_loaded = wm is not None and wm.get_workflow_template() is not None
    return {
        "server": comfy.server_address,
        "queue_pending": pending,
        "queue_running": running,
        "workflow_loaded": wf_loaded,
        "last_seed": comfy.last_seed,
        "total_generated": comfy.total_generated,
    }


def _upload_status(config):
    """Build the ``upload`` section, or return None if uploader stats are unavailable."""
    try:
        from media_uploader import get_stats as get_upload_stats, is_running as upload_running

        us = get_upload_stats()
        return {
            "configured": bool(config.media_upload_user),
            "running": upload_running(),
            "total_ok": us.total_ok,
            "total_fail": us.total_fail,
        }
    except Exception as exc:
        # Best-effort: a missing or broken uploader must not suppress the
        # rest of the snapshot.
        logger.debug("Status ticker: upload stats unavailable: %s", exc)
        return None


async def _broadcast_status(bus, snapshot: dict, registry) -> None:
    """Send *snapshot* to each connected user with their overrides, else to everyone."""
    connected = bus.connected_users
    if connected and registry:
        # Iterate over a copy: the connected set may change during the awaits
        # below if a client disconnects (assumes connected_users can be live —
        # TODO confirm).
        for ul in list(connected):
            user_overrides = registry.get_state_manager(ul).get_overrides()
            await bus.broadcast_to_user(ul, "status_snapshot", {**snapshot, "overrides": user_overrides})
    else:
        await bus.broadcast("status_snapshot", snapshot)


async def _status_ticker() -> None:
    """Broadcast a ``status_snapshot`` event every 5 seconds.

    Each tick includes whichever sections are available:

    * ``bot``: latency and uptime, when a bot is registered;
    * ``comfy``: queue depth and counters, when a ComfyUI client is registered;
    * ``upload``: media-uploader stats, when a config is registered and the
      stats can be read.

    Errors in a tick are logged at DEBUG and the loop keeps running. Only
    cancellation stops it.
    """
    from web.deps import get_bot, get_comfy, get_config, get_user_registry

    bus = get_bus()

    while True:
        await asyncio.sleep(5)
        try:
            bot = get_bot()
            comfy = get_comfy()
            config = get_config()

            snapshot: dict = {}
            if bot is not None:
                snapshot["bot"] = _bot_status(bot)
            if comfy is not None:
                snapshot["comfy"] = await _comfy_status(comfy)
            if config is not None:
                upload = _upload_status(config)
                if upload is not None:
                    snapshot["upload"] = upload

            await _broadcast_status(bus, snapshot, get_user_registry())
        except Exception as exc:
            logger.debug("Status ticker error: %s", exc)
|
|
|
|
|
|
async def _server_state_poller() -> None:
    """Every 10 seconds, broadcast ComfyUI's service state and HTTP reachability.

    Each ``server_state`` event carries ``state`` (from get_service_state for
    the configured service name) and ``http_reachable`` (from the ComfyUI
    client's check_connection(), or False when no client is registered).
    Ticks are skipped while no config is loaded. Per-tick errors are logged
    at DEBUG and the loop keeps running.
    """
    from web.deps import get_config

    async def _unreachable():
        # Stand-in for check_connection() when there is no ComfyUI client.
        return False

    bus = get_bus()

    while True:
        await asyncio.sleep(10)
        try:
            cfg = get_config()
            if cfg is None:
                continue

            from commands.server import get_service_state
            from web.deps import get_comfy

            client = get_comfy()
            state, reachable = await asyncio.gather(
                get_service_state(cfg.comfy_service_name),
                client.check_connection() if client else _unreachable(),
            )
            payload = {
                "state": state,
                "http_reachable": reachable,
            }
            await bus.broadcast("server_state", payload)
        except Exception as exc:
            logger.debug("Server state poller error: %s", exc)
|