""" web/app.py ========== FastAPI application factory. The app is created once and shared between the Uvicorn server (started from bot.py via asyncio.gather) and tests. Startup tasks: - Background status ticker (broadcasts status_snapshot every 5s to all clients) - Background NSSM poll (broadcasts server_state every 10s to all clients) """ from __future__ import annotations import asyncio import logging import mimetypes import os from pathlib import Path from fastapi import FastAPI from starlette.exceptions import HTTPException as _HTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request as _Request # Windows registry can map .js → text/plain; override to the correct types # before StaticFiles reads them. mimetypes.add_type("application/javascript", ".js") mimetypes.add_type("application/javascript", ".mjs") mimetypes.add_type("text/css", ".css") mimetypes.add_type("application/wasm", ".wasm") from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles class _NoCacheHTMLMiddleware(BaseHTTPMiddleware): """Force browsers to revalidate index.html on every request. Vite hashes JS/CSS filenames on every build so those assets are naturally cache-busted. index.html itself has a stable name, so without an explicit Cache-Control header mobile browsers apply heuristic caching and keep serving a stale copy after a redeploy. """ async def dispatch(self, request: _Request, call_next): response = await call_next(request) ct = response.headers.get("content-type", "") if "text/html" in ct: response.headers["Cache-Control"] = "no-cache, must-revalidate" return response class _SecurityHeadersMiddleware(BaseHTTPMiddleware): """Add security headers to every response.""" _CSP = ( "default-src 'self'; " "script-src 'self' 'unsafe-inline'; " "style-src 'self' 'unsafe-inline'; " "img-src 'self' data: blob:; " "connect-src 'self' wss:; " "frame-ancestors 'none';" ) async def dispatch(self, request: _Request, call_next): response = await call_next(request) response.headers["X-Content-Type-Options"] = "nosniff" response.headers["X-Frame-Options"] = "DENY" response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" response.headers["Content-Security-Policy"] = self._CSP return response class _SPAStaticFiles(StaticFiles): """StaticFiles with SPA fallback: serve index.html for unknown paths. Starlette's html=True only serves index.html for directory requests. This subclass additionally returns index.html for any path that has no matching file, so client-side routes like /generate work on refresh. """ async def __call__(self, scope, receive, send) -> None: # StaticFiles only handles HTTP. In Starlette 0.52+, Mount('/') returns # Match.FULL for ALL WebSocket scopes, so a WebSocket connection can # reach here if the dedicated /ws route somehow doesn't match first. # Close gracefully instead of asserting. if scope.get("type") != "http": from starlette.websockets import WebSocketClose logger.warning( "WebSocket or non-HTTP scope reached StaticFiles " "(path=%r, type=%r) — closing gracefully. " "This indicates a routing issue; /ws route did not match.", scope.get("path"), scope.get("type"), ) await WebSocketClose()(scope, receive, send) return await super().__call__(scope, receive, send) async def get_response(self, path: str, scope): try: return await super().get_response(path, scope) except _HTTPException as ex: if ex.status_code == 404: return await super().get_response("index.html", scope) raise from web.ws_bus import get_bus logger = logging.getLogger(__name__) _PROJECT_ROOT = Path(__file__).resolve().parent.parent _WEB_STATIC = _PROJECT_ROOT / "web-static" def create_app() -> FastAPI: """Create and configure the FastAPI application.""" app = FastAPI( title="ComfyUI Bot Web UI", version="1.0.0", docs_url=None, redoc_url=None, openapi_url=None, ) # CORS — only allow explicitly configured origins; empty = no cross-origin _cors_origins = [o.strip() for o in os.getenv("CORS_ORIGINS", "").split(",") if o.strip()] if _cors_origins: app.add_middleware( CORSMiddleware, allow_origins=_cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Security headers on every response app.add_middleware(_SecurityHeadersMiddleware) # Prevent browsers from caching index.html across deploys app.add_middleware(_NoCacheHTMLMiddleware) # Register API routers from web.routers.auth_router import router as auth_router from web.routers.admin_router import router as admin_router from web.routers.status_router import router as status_router from web.routers.state_router import router as state_router from web.routers.generate_router import router as generate_router from web.routers.inputs_router import router as inputs_router from web.routers.presets_router import router as presets_router from web.routers.server_router import router as server_router from web.routers.history_router import router as history_router from web.routers.share_router import router as share_router from web.routers.workflow_router import router as workflow_router from web.routers.ws_router import router as ws_router app.include_router(auth_router, prefix="/api/auth", tags=["auth"]) app.include_router(admin_router, prefix="/api/admin", tags=["admin"]) app.include_router(status_router, prefix="/api", tags=["status"]) app.include_router(state_router, prefix="/api", tags=["state"]) app.include_router(generate_router, prefix="/api", tags=["generate"]) app.include_router(inputs_router, prefix="/api/inputs", tags=["inputs"]) app.include_router(presets_router, prefix="/api/presets", tags=["presets"]) app.include_router(server_router, prefix="/api", tags=["server"]) app.include_router(history_router, prefix="/api/history", tags=["history"]) app.include_router(share_router, prefix="/api/share", tags=["share"]) app.include_router(workflow_router, prefix="/api/workflow", tags=["workflow"]) app.include_router(ws_router, tags=["ws"]) # Belt-and-suspenders: register /ws directly on the app so it sits at the # top of the route list and is guaranteed to be checked before the # StaticFiles catch-all mount below. In Starlette 0.52+, Mount('/') # returns Match.FULL for every WebSocket scope, which means if the route # from include_router is somehow not matched first, the mount wins and # StaticFiles crashes with AssertionError. Having the route registered # twice is harmless — the first Match.FULL in the route list wins. from web.routers.ws_router import websocket_endpoint as _ws_endpoint app.add_websocket_route("/ws", _ws_endpoint) # Serve frontend static files (if built) if _WEB_STATIC.exists() and any(_WEB_STATIC.iterdir()): app.mount("/", _SPAStaticFiles(directory=str(_WEB_STATIC), html=True), name="static") logger.info("Serving frontend from %s", _WEB_STATIC) @app.on_event("startup") async def _startup(): asyncio.create_task(_status_ticker()) asyncio.create_task(_server_state_poller()) logger.info("Web background tasks started") return app # --------------------------------------------------------------------------- # Background tasks # --------------------------------------------------------------------------- async def _status_ticker() -> None: """Broadcast status_snapshot to all clients every 5 seconds.""" from web.deps import get_bot, get_comfy, get_config bus = get_bus() while True: await asyncio.sleep(5) try: bot = get_bot() comfy = get_comfy() config = get_config() snapshot: dict = {} if bot is not None: lat = bot.latency lat_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0 import datetime as _dt start = getattr(bot, "start_time", None) uptime = "" if start: delta = _dt.datetime.now(_dt.timezone.utc) - start total = int(delta.total_seconds()) h, rem = divmod(total, 3600) m, s = divmod(rem, 60) uptime = f"{h}h {m}m {s}s" if h else (f"{m}m {s}s" if m else f"{s}s") snapshot["bot"] = {"latency_ms": lat_ms, "uptime": uptime} if comfy is not None: q = await comfy.get_comfy_queue() pending = len(q.get("queue_pending", [])) if q else 0 running = len(q.get("queue_running", [])) if q else 0 wm = getattr(comfy, "workflow_manager", None) wf_loaded = wm is not None and wm.get_workflow_template() is not None snapshot["comfy"] = { "server": comfy.server_address, "queue_pending": pending, "queue_running": running, "workflow_loaded": wf_loaded, "last_seed": comfy.last_seed, "total_generated": comfy.total_generated, } if config is not None: from media_uploader import get_stats as get_upload_stats, is_running as upload_running try: us = get_upload_stats() snapshot["upload"] = { "configured": bool(config.media_upload_user), "running": upload_running(), "total_ok": us.total_ok, "total_fail": us.total_fail, } except Exception: pass from web.deps import get_user_registry registry = get_user_registry() connected = bus.connected_users if connected and registry: for ul in connected: user_overrides = registry.get_state_manager(ul).get_overrides() await bus.broadcast_to_user(ul, "status_snapshot", {**snapshot, "overrides": user_overrides}) else: await bus.broadcast("status_snapshot", snapshot) except Exception as exc: logger.debug("Status ticker error: %s", exc) async def _server_state_poller() -> None: """Poll NSSM service state and broadcast server_state every 10 seconds.""" from web.deps import get_config bus = get_bus() while True: await asyncio.sleep(10) try: config = get_config() if config is None: continue from commands.server import get_service_state from web.deps import get_comfy async def _false(): return False comfy = get_comfy() service_state, http_reachable = await asyncio.gather( get_service_state(config.comfy_service_name), comfy.check_connection() if comfy else _false(), ) await bus.broadcast("server_state", { "state": service_state, "http_reachable": http_reachable, }) except Exception as exc: logger.debug("Server state poller error: %s", exc)