Initial commit — ComfyUI Discord bot + web UI
Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI (React/TS/Vite/Tailwind), ComfyUI integration, generation history DB, preset manager, workflow inspector, and all supporting modules. Excluded from tracking: .env, invite_tokens.json, *.db (SQLite), current-workflow-changes.json, user_settings/, presets/, logs/, web-static/ (build output), frontend/node_modules/. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
1
web/routers/__init__.py
Normal file
1
web/routers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# web.routers package
|
||||
88
web/routers/admin_router.py
Normal file
88
web/routers/admin_router.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""POST /api/admin/login; GET/POST/DELETE /api/admin/tokens"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from web.auth import create_jwt, require_admin
|
||||
from web.deps import get_config
|
||||
from web.login_guard import get_guard, get_real_ip
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_COOKIE = "ttb_session"
|
||||
audit = logging.getLogger("audit")
|
||||
|
||||
|
||||
class AdminLoginRequest(BaseModel):
|
||||
password: str
|
||||
|
||||
|
||||
class CreateTokenRequest(BaseModel):
|
||||
label: str
|
||||
admin: bool = False
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def admin_login(body: AdminLoginRequest, request: Request, response: Response):
|
||||
"""Admin password login → admin JWT cookie."""
|
||||
config = get_config()
|
||||
expected_pw = config.admin_password if config else None
|
||||
|
||||
ip = get_real_ip(request)
|
||||
get_guard().check(ip)
|
||||
|
||||
# Constant-time comparison to prevent timing attacks
|
||||
if not expected_pw or not hmac.compare_digest(body.password, expected_pw):
|
||||
get_guard().record_failure(ip)
|
||||
audit.info("admin.login ip=%s success=False", ip)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Wrong password")
|
||||
|
||||
get_guard().record_success(ip)
|
||||
audit.info("admin.login ip=%s success=True", ip)
|
||||
|
||||
expire_hours = config.web_jwt_expire_hours if config else 8
|
||||
jwt_token = create_jwt("admin", admin=True, expire_hours=expire_hours)
|
||||
response.set_cookie(
|
||||
_COOKIE, jwt_token,
|
||||
httponly=True, secure=True, samesite="strict",
|
||||
max_age=expire_hours * 3600,
|
||||
)
|
||||
return {"label": "admin", "admin": True}
|
||||
|
||||
|
||||
@router.get("/tokens")
|
||||
async def list_tokens(_: dict = Depends(require_admin)):
|
||||
"""List all invite tokens (hashes shown, labels safe)."""
|
||||
from token_store import list_tokens as _list
|
||||
config = get_config()
|
||||
token_file = config.web_token_file if config else "invite_tokens.json"
|
||||
records = _list(token_file)
|
||||
# Don't return hashes to the UI
|
||||
return [{"id": r["id"], "label": r["label"], "admin": r.get("admin", False),
|
||||
"created_at": r.get("created_at")} for r in records]
|
||||
|
||||
|
||||
@router.post("/tokens")
|
||||
async def create_token(body: CreateTokenRequest, _: dict = Depends(require_admin)):
|
||||
"""Create a new invite token. Returns the plaintext token (shown once)."""
|
||||
from token_store import create_token as _create
|
||||
config = get_config()
|
||||
token_file = config.web_token_file if config else "invite_tokens.json"
|
||||
plaintext = _create(body.label, token_file, admin=body.admin)
|
||||
return {"token": plaintext, "label": body.label, "admin": body.admin}
|
||||
|
||||
|
||||
@router.delete("/tokens/{token_id}")
|
||||
async def revoke_token(token_id: str, _: dict = Depends(require_admin)):
|
||||
"""Revoke an invite token by ID."""
|
||||
from token_store import revoke_token as _revoke
|
||||
config = get_config()
|
||||
token_file = config.web_token_file if config else "invite_tokens.json"
|
||||
ok = _revoke(token_id, token_file)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="Token not found")
|
||||
return {"ok": True}
|
||||
64
web/routers/auth_router.py
Normal file
64
web/routers/auth_router.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""POST /api/auth/login|logout; GET /api/auth/me"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from web.auth import create_jwt, require_auth
|
||||
from web.deps import get_config
|
||||
from web.login_guard import get_guard, get_real_ip
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_COOKIE = "ttb_session"
|
||||
audit = logging.getLogger("audit")
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def login(body: LoginRequest, request: Request, response: Response):
|
||||
"""Exchange an invite token for a JWT session cookie."""
|
||||
from token_store import verify_token
|
||||
config = get_config()
|
||||
token_file = config.web_token_file if config else "invite_tokens.json"
|
||||
expire_hours = config.web_jwt_expire_hours if config else 8
|
||||
|
||||
ip = get_real_ip(request)
|
||||
get_guard().check(ip)
|
||||
|
||||
record = verify_token(body.token, token_file)
|
||||
if record is None:
|
||||
get_guard().record_failure(ip)
|
||||
audit.info("auth.login ip=%s success=False", ip)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||
|
||||
label: str = record["label"]
|
||||
admin: bool = record.get("admin", False)
|
||||
get_guard().record_success(ip)
|
||||
audit.info("auth.login ip=%s success=True label=%s", ip, label)
|
||||
|
||||
jwt_token = create_jwt(label, admin=admin, expire_hours=expire_hours)
|
||||
response.set_cookie(
|
||||
_COOKIE, jwt_token,
|
||||
httponly=True, secure=True, samesite="strict",
|
||||
max_age=expire_hours * 3600,
|
||||
)
|
||||
return {"label": label, "admin": admin}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(response: Response):
|
||||
"""Clear the session cookie."""
|
||||
response.delete_cookie(_COOKIE)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def me(user: dict = Depends(require_auth)):
|
||||
"""Return current user info."""
|
||||
return {"label": user["sub"], "admin": user.get("admin", False)}
|
||||
255
web/routers/generate_router.py
Normal file
255
web/routers/generate_router.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""POST /api/generate and /api/workflow-gen"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_comfy, get_config, get_user_registry
|
||||
from web.ws_bus import get_bus
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
negative_prompt: Optional[str] = None
|
||||
overrides: Optional[Dict[str, Any]] = None # extra per-request overrides
|
||||
|
||||
|
||||
class WorkflowGenRequest(BaseModel):
|
||||
count: int = 1
|
||||
overrides: Optional[Dict[str, Any]] = None # per-request overrides (merged with state)
|
||||
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate(body: GenerateRequest, user: dict = Depends(require_auth)):
|
||||
"""Submit a prompt-based generation to ComfyUI."""
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "ComfyUI client not available")
|
||||
|
||||
user_label: str = user["sub"]
|
||||
bus = get_bus()
|
||||
registry = get_user_registry()
|
||||
|
||||
# Temporary seed override from request
|
||||
if body.overrides and "seed" in body.overrides:
|
||||
seed_override = body.overrides["seed"]
|
||||
elif registry:
|
||||
seed_override = registry.get_state_manager(user_label).get_seed()
|
||||
else:
|
||||
seed_override = comfy.state_manager.get_seed()
|
||||
|
||||
overrides_for_gen = {"prompt": body.prompt}
|
||||
if body.negative_prompt:
|
||||
overrides_for_gen["negative_prompt"] = body.negative_prompt
|
||||
if seed_override is not None:
|
||||
overrides_for_gen["seed"] = seed_override
|
||||
|
||||
# Also apply any extra per-request overrides
|
||||
if body.overrides:
|
||||
overrides_for_gen.update(body.overrides)
|
||||
|
||||
# Get queue position estimate
|
||||
depth = await comfy.get_queue_depth()
|
||||
|
||||
# Start generation as background task so we can return the prompt_id immediately
|
||||
prompt_id_holder: list = []
|
||||
|
||||
async def _run():
|
||||
# Use the user's own workflow template
|
||||
if registry:
|
||||
template = registry.get_workflow_template(user_label)
|
||||
else:
|
||||
template = comfy.workflow_manager.get_workflow_template()
|
||||
if not template:
|
||||
await bus.broadcast_to_user(user_label, "generation_error", {
|
||||
"prompt_id": None, "error": "No workflow template loaded"
|
||||
})
|
||||
return
|
||||
import uuid
|
||||
pid = str(uuid.uuid4())
|
||||
prompt_id_holder.append(pid)
|
||||
|
||||
def on_progress(node, pid_):
|
||||
asyncio.create_task(bus.broadcast("node_executing", {
|
||||
"node": node, "prompt_id": pid_
|
||||
}))
|
||||
|
||||
workflow, applied = comfy.inspector.inject_overrides(template, overrides_for_gen)
|
||||
seed_used = applied.get("seed")
|
||||
comfy.last_seed = seed_used
|
||||
|
||||
try:
|
||||
images, videos = await comfy._general_generate(workflow, pid, on_progress)
|
||||
except Exception as exc:
|
||||
logger.exception("Generation error for prompt %s", pid)
|
||||
await bus.broadcast_to_user(user_label, "generation_error", {
|
||||
"prompt_id": pid, "error": str(exc)
|
||||
})
|
||||
return
|
||||
|
||||
comfy.last_prompt_id = pid
|
||||
comfy.total_generated += 1
|
||||
|
||||
# Persist to DB before flush_pending deletes local files
|
||||
config = get_config()
|
||||
try:
|
||||
from generation_db import record_generation, record_file
|
||||
gen_id = record_generation(pid, "web", user_label, overrides_for_gen, seed_used)
|
||||
for i, img_data in enumerate(images):
|
||||
record_file(gen_id, f"image_{i:04d}.png", img_data)
|
||||
if config and videos:
|
||||
for vid in videos:
|
||||
vsub = vid.get("video_subfolder", "")
|
||||
vname = vid.get("video_name", "")
|
||||
vpath = (
|
||||
Path(config.comfy_output_path) / vsub / vname
|
||||
if vsub
|
||||
else Path(config.comfy_output_path) / vname
|
||||
)
|
||||
try:
|
||||
record_file(gen_id, vname, vpath.read_bytes())
|
||||
except OSError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to record generation to DB: %s", exc)
|
||||
|
||||
# Flush auto-upload
|
||||
if config:
|
||||
from media_uploader import flush_pending
|
||||
asyncio.create_task(flush_pending(
|
||||
Path(config.comfy_output_path),
|
||||
config.media_upload_user,
|
||||
config.media_upload_pass,
|
||||
))
|
||||
|
||||
await bus.broadcast("queue_update", {
|
||||
"prompt_id": pid,
|
||||
"status": "complete",
|
||||
})
|
||||
await bus.broadcast_to_user(user_label, "generation_complete", {
|
||||
"prompt_id": pid,
|
||||
"seed": seed_used,
|
||||
"image_count": len(images),
|
||||
"video_count": len(videos),
|
||||
})
|
||||
|
||||
asyncio.create_task(_run())
|
||||
|
||||
return {
|
||||
"queued": True,
|
||||
"queue_position": depth + 1,
|
||||
"message": "Generation submitted to ComfyUI",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/workflow-gen")
|
||||
async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_auth)):
|
||||
"""Submit workflow-based generation(s) to ComfyUI."""
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "ComfyUI client not available")
|
||||
|
||||
user_label: str = user["sub"]
|
||||
bus = get_bus()
|
||||
registry = get_user_registry()
|
||||
count = max(1, min(body.count, 20)) # cap at 20
|
||||
|
||||
async def _run_one():
|
||||
# Use the user's own state and template
|
||||
if registry:
|
||||
user_sm = registry.get_state_manager(user_label)
|
||||
user_template = registry.get_workflow_template(user_label)
|
||||
else:
|
||||
user_sm = comfy.state_manager
|
||||
user_template = comfy.workflow_manager.get_workflow_template()
|
||||
|
||||
if not user_template:
|
||||
await bus.broadcast_to_user(user_label, "generation_error", {
|
||||
"prompt_id": None, "error": "No workflow template loaded"
|
||||
})
|
||||
return
|
||||
|
||||
overrides = user_sm.get_overrides()
|
||||
if body.overrides:
|
||||
overrides = {**overrides, **body.overrides}
|
||||
|
||||
import uuid
|
||||
pid = str(uuid.uuid4())
|
||||
|
||||
def on_progress(node, pid_):
|
||||
asyncio.create_task(bus.broadcast("node_executing", {
|
||||
"node": node, "prompt_id": pid_
|
||||
}))
|
||||
|
||||
workflow, applied = comfy.inspector.inject_overrides(user_template, overrides)
|
||||
seed_used = applied.get("seed")
|
||||
comfy.last_seed = seed_used
|
||||
|
||||
try:
|
||||
images, videos = await comfy._general_generate(workflow, pid, on_progress)
|
||||
except Exception as exc:
|
||||
logger.exception("Workflow gen error")
|
||||
await bus.broadcast_to_user(user_label, "generation_error", {
|
||||
"prompt_id": None, "error": str(exc)
|
||||
})
|
||||
return
|
||||
|
||||
comfy.last_prompt_id = pid
|
||||
comfy.total_generated += 1
|
||||
|
||||
config = get_config()
|
||||
try:
|
||||
from generation_db import record_generation, record_file
|
||||
gen_id = record_generation(pid, "web", user_label, overrides, seed_used)
|
||||
for i, img_data in enumerate(images):
|
||||
record_file(gen_id, f"image_{i:04d}.png", img_data)
|
||||
if config and videos:
|
||||
for vid in videos:
|
||||
vsub = vid.get("video_subfolder", "")
|
||||
vname = vid.get("video_name", "")
|
||||
vpath = (
|
||||
Path(config.comfy_output_path) / vsub / vname
|
||||
if vsub
|
||||
else Path(config.comfy_output_path) / vname
|
||||
)
|
||||
try:
|
||||
record_file(gen_id, vname, vpath.read_bytes())
|
||||
except OSError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to record generation to DB: %s", exc)
|
||||
|
||||
if config:
|
||||
from media_uploader import flush_pending
|
||||
asyncio.create_task(flush_pending(
|
||||
Path(config.comfy_output_path),
|
||||
config.media_upload_user,
|
||||
config.media_upload_pass,
|
||||
))
|
||||
|
||||
await bus.broadcast("queue_update", {"prompt_id": pid, "status": "complete"})
|
||||
await bus.broadcast_to_user(user_label, "generation_complete", {
|
||||
"prompt_id": pid,
|
||||
"seed": seed_used,
|
||||
"image_count": len(images),
|
||||
"video_count": len(videos),
|
||||
})
|
||||
|
||||
depth = await comfy.get_queue_depth()
|
||||
for _ in range(count):
|
||||
asyncio.create_task(_run_one())
|
||||
|
||||
return {
|
||||
"queued": True,
|
||||
"count": count,
|
||||
"queue_position": depth + 1,
|
||||
}
|
||||
143
web/routers/history_router.py
Normal file
143
web/routers/history_router.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""GET /api/history; GET /api/history/{prompt_id}/images; GET /api/history/{prompt_id}/file/{filename};
|
||||
POST /api/history/{prompt_id}/share; DELETE /api/history/{prompt_id}/share"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import Response
|
||||
|
||||
from web.auth import require_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _assert_owns(prompt_id: str, user: dict) -> None:
|
||||
"""Raise 404 if the generation doesn't exist or doesn't belong to the user.
|
||||
|
||||
Returning the same 404 for both cases prevents leaking whether a
|
||||
prompt_id exists to users who don't own it. Admins bypass this check.
|
||||
"""
|
||||
if user.get("admin"):
|
||||
return
|
||||
from generation_db import get_generation
|
||||
gen = get_generation(prompt_id)
|
||||
if gen is None or gen["user_label"] != user["sub"]:
|
||||
raise HTTPException(404, "Not found")
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_history(
|
||||
user: dict = Depends(require_auth),
|
||||
q: Optional[str] = Query(None, description="Keyword to search in overrides JSON"),
|
||||
):
|
||||
"""Return generation history. Admins see all; regular users see only their own.
|
||||
Pass ?q=keyword to filter by prompt text or any override field."""
|
||||
from generation_db import (
|
||||
get_history as db_get_history,
|
||||
get_history_for_user,
|
||||
search_history,
|
||||
search_history_for_user,
|
||||
)
|
||||
if q and q.strip():
|
||||
if user.get("admin"):
|
||||
return {"history": search_history(q.strip(), limit=50)}
|
||||
return {"history": search_history_for_user(user["sub"], q.strip(), limit=50)}
|
||||
if user.get("admin"):
|
||||
return {"history": db_get_history(limit=50)}
|
||||
return {"history": get_history_for_user(user["sub"], limit=50)}
|
||||
|
||||
|
||||
@router.get("/{prompt_id}/images")
|
||||
async def get_history_images(prompt_id: str, user: dict = Depends(require_auth)):
|
||||
"""
|
||||
Fetch output files for a past generation.
|
||||
|
||||
Returns base64-encoded blobs from the local SQLite DB — works even after
|
||||
``flush_pending`` has deleted the files from disk.
|
||||
"""
|
||||
_assert_owns(prompt_id, user)
|
||||
from generation_db import get_files
|
||||
files = get_files(prompt_id)
|
||||
if not files:
|
||||
raise HTTPException(404, f"No files found for prompt_id {prompt_id!r}")
|
||||
return {
|
||||
"prompt_id": prompt_id,
|
||||
"images": [
|
||||
{
|
||||
"filename": f["filename"],
|
||||
"data": base64.b64encode(f["data"]).decode() if not f["mime_type"].startswith("video/") else None,
|
||||
"mime_type": f["mime_type"],
|
||||
}
|
||||
for f in files
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{prompt_id}/file/{filename}")
|
||||
async def get_history_file(
|
||||
prompt_id: str,
|
||||
filename: str,
|
||||
request: Request,
|
||||
user: dict = Depends(require_auth),
|
||||
):
|
||||
"""Stream a single output file, with HTTP range request support for video seeking."""
|
||||
_assert_owns(prompt_id, user)
|
||||
from generation_db import get_files
|
||||
files = get_files(prompt_id)
|
||||
matched = next((f for f in files if f["filename"] == filename), None)
|
||||
if matched is None:
|
||||
raise HTTPException(404, f"File {filename!r} not found for prompt_id {prompt_id!r}")
|
||||
|
||||
data: bytes = matched["data"]
|
||||
mime: str = matched["mime_type"]
|
||||
total = len(data)
|
||||
|
||||
range_header = request.headers.get("range")
|
||||
if range_header:
|
||||
range_val = range_header.replace("bytes=", "")
|
||||
start_str, _, end_str = range_val.partition("-")
|
||||
start = int(start_str) if start_str else 0
|
||||
end = int(end_str) if end_str else total - 1
|
||||
end = min(end, total - 1)
|
||||
chunk = data[start : end + 1]
|
||||
return Response(
|
||||
content=chunk,
|
||||
status_code=206,
|
||||
media_type=mime,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{total}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(len(chunk)),
|
||||
},
|
||||
)
|
||||
|
||||
return Response(
|
||||
content=data,
|
||||
media_type=mime,
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(total),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{prompt_id}/share")
|
||||
async def create_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
|
||||
"""Create a share token for a generation. Only the owner may share."""
|
||||
# Use the same 404-for-everything helper to avoid leaking prompt_id existence
|
||||
_assert_owns(prompt_id, user)
|
||||
from generation_db import create_share as db_create_share
|
||||
token = db_create_share(prompt_id, user["sub"])
|
||||
return {"share_token": token}
|
||||
|
||||
|
||||
@router.delete("/{prompt_id}/share")
|
||||
async def revoke_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
|
||||
"""Revoke a share token for a generation. Only the owner may revoke."""
|
||||
from generation_db import revoke_share as db_revoke_share
|
||||
deleted = db_revoke_share(prompt_id, user["sub"])
|
||||
if not deleted:
|
||||
raise HTTPException(404, "No active share found for this generation")
|
||||
return {"ok": True}
|
||||
194
web/routers/inputs_router.py
Normal file
194
web/routers/inputs_router.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""GET/POST/DELETE /api/inputs; GET /api/inputs/{id}/image; POST /api/inputs/{id}/activate"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import Response
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_config, get_user_registry
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_inputs(_: dict = Depends(require_auth)):
|
||||
"""List all input images (Discord + web uploads)."""
|
||||
from input_image_db import get_all_images
|
||||
rows = get_all_images()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def upload_input(
|
||||
file: UploadFile = File(...),
|
||||
slot_key: Optional[str] = Form(default=None),
|
||||
user: dict = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
Upload an input image.
|
||||
|
||||
Stores image bytes directly in SQLite. If *slot_key* is provided the
|
||||
image is immediately activated for that slot (writes to ComfyUI input
|
||||
folder and updates the user's state override).
|
||||
|
||||
The physical slot file uses a namespaced key ``<user_label>_<slot_key>``
|
||||
so concurrent users each get their own active image file.
|
||||
"""
|
||||
config = get_config()
|
||||
if config is None:
|
||||
raise HTTPException(503, "Config not available")
|
||||
|
||||
data = await file.read()
|
||||
filename = file.filename or "upload.png"
|
||||
|
||||
from input_image_db import upsert_image, activate_image_for_slot
|
||||
row_id = upsert_image(
|
||||
original_message_id=0, # sentinel for web uploads
|
||||
bot_reply_id=0,
|
||||
channel_id=0,
|
||||
filename=filename,
|
||||
image_data=data,
|
||||
)
|
||||
|
||||
activated_filename: str | None = None
|
||||
if slot_key:
|
||||
user_label: str = user["sub"]
|
||||
namespaced_key = f"{user_label}_{slot_key}"
|
||||
activated_filename = activate_image_for_slot(
|
||||
row_id, namespaced_key, config.comfy_input_path
|
||||
)
|
||||
registry = get_user_registry()
|
||||
if registry:
|
||||
registry.get_state_manager(user_label).set_override(slot_key, activated_filename)
|
||||
else:
|
||||
from web.deps import get_comfy
|
||||
comfy = get_comfy()
|
||||
if comfy:
|
||||
comfy.state_manager.set_override(slot_key, activated_filename)
|
||||
|
||||
return {"id": row_id, "filename": filename, "slot_key": slot_key, "activated_filename": activated_filename}
|
||||
|
||||
|
||||
@router.post("/{row_id}/activate")
|
||||
async def activate_input(
|
||||
row_id: int,
|
||||
slot_key: str = Body(default="input_image", embed=True),
|
||||
user: dict = Depends(require_auth),
|
||||
):
|
||||
"""Write the stored image to the ComfyUI input folder and set the user's slot override."""
|
||||
config = get_config()
|
||||
if config is None:
|
||||
raise HTTPException(503, "Config not available")
|
||||
|
||||
from input_image_db import get_image, activate_image_for_slot
|
||||
row = get_image(row_id)
|
||||
if row is None:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
user_label: str = user["sub"]
|
||||
namespaced_key = f"{user_label}_{slot_key}"
|
||||
|
||||
try:
|
||||
filename = activate_image_for_slot(row_id, namespaced_key, config.comfy_input_path)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(409, str(exc))
|
||||
|
||||
registry = get_user_registry()
|
||||
if registry:
|
||||
registry.get_state_manager(user_label).set_override(slot_key, filename)
|
||||
else:
|
||||
from web.deps import get_comfy
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "State manager not available")
|
||||
comfy.state_manager.set_override(slot_key, filename)
|
||||
|
||||
return {"ok": True, "slot_key": slot_key, "filename": filename}
|
||||
|
||||
|
||||
@router.delete("/{row_id}")
|
||||
async def delete_input(row_id: int, _: dict = Depends(require_auth)):
|
||||
"""Delete an input image record (and its active slot file if applicable)."""
|
||||
from input_image_db import get_image, delete_image
|
||||
row = get_image(row_id)
|
||||
if row is None:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
config = get_config()
|
||||
delete_image(row_id, comfy_input_path=config.comfy_input_path if config else None)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/{row_id}/image")
|
||||
async def get_input_image(row_id: int, _: dict = Depends(require_auth)):
|
||||
"""Serve the raw image bytes stored in the database for a given input image row."""
|
||||
from input_image_db import get_image, get_image_data
|
||||
row = get_image(row_id)
|
||||
if row is None:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
data = get_image_data(row_id)
|
||||
if data is None:
|
||||
raise HTTPException(404, "Image data not available — re-upload to backfill")
|
||||
|
||||
mime, _ = mimetypes.guess_type(row["filename"])
|
||||
return Response(content=data, media_type=mime or "application/octet-stream")
|
||||
|
||||
|
||||
def _pil_resize_response(data: bytes, filename: str, max_size: int, quality: int) -> Response:
|
||||
"""Resize image bytes with Pillow and return a JPEG Response. Raises on failure."""
|
||||
import io
|
||||
from PIL import Image as _PIL
|
||||
img = _PIL.open(io.BytesIO(data))
|
||||
img.thumbnail((max_size, max_size), _PIL.LANCZOS)
|
||||
buf = io.BytesIO()
|
||||
img.convert("RGB").save(buf, "JPEG", quality=quality, optimize=True)
|
||||
return Response(
|
||||
content=buf.getvalue(),
|
||||
media_type="image/jpeg",
|
||||
headers={"Cache-Control": "public, max-age=86400"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{row_id}/thumb")
|
||||
async def get_input_thumb(row_id: int, _: dict = Depends(require_auth)):
|
||||
"""Serve a small compressed thumbnail (max 200 px, JPEG 65 %) for fast previews."""
|
||||
from input_image_db import get_image, get_image_data
|
||||
row = get_image(row_id)
|
||||
if row is None:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
data = get_image_data(row_id)
|
||||
if data is None:
|
||||
raise HTTPException(404, "Image data not available — re-upload to backfill")
|
||||
|
||||
try:
|
||||
return _pil_resize_response(data, row["filename"], max_size=200, quality=65)
|
||||
except Exception:
|
||||
mime, _ = mimetypes.guess_type(row["filename"])
|
||||
return Response(content=data, media_type=mime or "application/octet-stream")
|
||||
|
||||
|
||||
@router.get("/{row_id}/mid")
|
||||
async def get_input_mid(row_id: int, _: dict = Depends(require_auth)):
|
||||
"""Serve a medium compressed image (max 800 px, JPEG 80 %) for progressive loading."""
|
||||
from input_image_db import get_image, get_image_data
|
||||
row = get_image(row_id)
|
||||
if row is None:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
data = get_image_data(row_id)
|
||||
if data is None:
|
||||
raise HTTPException(404, "Image data not available — re-upload to backfill")
|
||||
|
||||
try:
|
||||
return _pil_resize_response(data, row["filename"], max_size=800, quality=80)
|
||||
except Exception:
|
||||
mime, _ = mimetypes.guess_type(row["filename"])
|
||||
return Response(content=data, media_type=mime or "application/octet-stream")
|
||||
153
web/routers/presets_router.py
Normal file
153
web/routers/presets_router.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""CRUD for workflow presets via /api/presets"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_comfy, get_user_registry
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SavePresetRequest(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class SaveFromHistoryRequest(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
def _get_pm():
|
||||
from web.deps import get_bot
|
||||
bot = get_bot()
|
||||
pm = getattr(bot, "preset_manager", None) if bot else None
|
||||
if pm is None:
|
||||
from preset_manager import PresetManager
|
||||
pm = PresetManager()
|
||||
return pm
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_presets(_: dict = Depends(require_auth)):
|
||||
pm = _get_pm()
|
||||
return {"presets": pm.list_preset_details()}
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def save_preset(body: SavePresetRequest, user: dict = Depends(require_auth)):
|
||||
"""Capture the user's overrides + workflow template as a named preset."""
|
||||
user_label: str = user["sub"]
|
||||
registry = get_user_registry()
|
||||
pm = _get_pm()
|
||||
|
||||
if registry:
|
||||
workflow_template = registry.get_workflow_template(user_label)
|
||||
overrides = registry.get_state_manager(user_label).get_overrides()
|
||||
else:
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "ComfyUI not available")
|
||||
workflow_template = comfy.get_workflow_template()
|
||||
overrides = comfy.state_manager.get_overrides()
|
||||
|
||||
try:
|
||||
pm.save(body.name, workflow_template, overrides, owner=user_label, description=body.description)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(400, str(exc))
|
||||
return {"ok": True, "name": body.name}
|
||||
|
||||
|
||||
@router.get("/{name}")
|
||||
async def get_preset(name: str, _: dict = Depends(require_auth)):
|
||||
pm = _get_pm()
|
||||
data = pm.load(name)
|
||||
if data is None:
|
||||
raise HTTPException(404, "Preset not found")
|
||||
return data
|
||||
|
||||
|
||||
@router.post("/{name}/load")
|
||||
async def load_preset(name: str, user: dict = Depends(require_auth)):
|
||||
"""Restore overrides (and optionally workflow template) from a preset into the user's state."""
|
||||
pm = _get_pm()
|
||||
data = pm.load(name)
|
||||
if data is None:
|
||||
raise HTTPException(404, "Preset not found")
|
||||
|
||||
user_label: str = user["sub"]
|
||||
registry = get_user_registry()
|
||||
|
||||
if registry:
|
||||
wf = data.get("workflow")
|
||||
if wf:
|
||||
registry.set_workflow(user_label, wf, name)
|
||||
else:
|
||||
# No workflow in preset — just clear overrides and restore state
|
||||
registry.get_state_manager(user_label).clear_overrides()
|
||||
state = data.get("state", {})
|
||||
sm = registry.get_state_manager(user_label)
|
||||
for k, v in state.items():
|
||||
if v is not None:
|
||||
sm.set_override(k, v)
|
||||
else:
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "ComfyUI not available")
|
||||
comfy.state_manager.clear_overrides()
|
||||
state = data.get("state", {})
|
||||
for k, v in state.items():
|
||||
if v is not None:
|
||||
comfy.state_manager.set_override(k, v)
|
||||
wf = data.get("workflow")
|
||||
if wf:
|
||||
comfy.workflow_manager.set_workflow_template(wf)
|
||||
|
||||
return {"ok": True, "name": name, "overrides_restored": list(data.get("state", {}).keys())}
|
||||
|
||||
|
||||
@router.delete("/{name}")
|
||||
async def delete_preset(name: str, user: dict = Depends(require_auth)):
|
||||
pm = _get_pm()
|
||||
data = pm.load(name)
|
||||
if data is None:
|
||||
raise HTTPException(404, "Preset not found")
|
||||
|
||||
user_label: str = user["sub"]
|
||||
is_admin = user.get("admin") is True
|
||||
owner = data.get("owner")
|
||||
if owner is not None and owner != user_label and not is_admin:
|
||||
raise HTTPException(403, "You do not have permission to delete this preset")
|
||||
|
||||
pm.delete(name)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/from-history/{prompt_id}")
|
||||
async def save_preset_from_history(
|
||||
prompt_id: str,
|
||||
body: SaveFromHistoryRequest,
|
||||
user: dict = Depends(require_auth),
|
||||
):
|
||||
"""Create a preset from a past generation's overrides."""
|
||||
from generation_db import get_generation_full
|
||||
|
||||
gen = get_generation_full(prompt_id)
|
||||
if gen is None:
|
||||
raise HTTPException(404, "Generation not found")
|
||||
|
||||
user_label: str = user["sub"]
|
||||
is_admin = user.get("admin") is True
|
||||
if not is_admin and gen.get("user_label") != user_label:
|
||||
raise HTTPException(404, "Generation not found")
|
||||
|
||||
pm = _get_pm()
|
||||
try:
|
||||
pm.save(body.name, None, gen["overrides"], owner=user_label, description=body.description)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(400, str(exc))
|
||||
return {"ok": True, "name": body.name}
|
||||
90
web/routers/server_router.py
Normal file
90
web/routers/server_router.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""GET/POST /api/server/{action}; GET /api/logs/tail"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_config, get_comfy
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/server/status")
|
||||
async def server_status(_: dict = Depends(require_auth)):
|
||||
"""Return NSSM service state and HTTP health."""
|
||||
config = get_config()
|
||||
if config is None:
|
||||
raise HTTPException(503, "Config not available")
|
||||
from commands.server import get_service_state
|
||||
import asyncio
|
||||
|
||||
async def _false():
|
||||
return False
|
||||
|
||||
comfy = get_comfy()
|
||||
service_state, http_ok = await asyncio.gather(
|
||||
get_service_state(config.comfy_service_name),
|
||||
comfy.check_connection() if comfy else _false(),
|
||||
)
|
||||
return {"service_state": service_state, "http_reachable": http_ok}
|
||||
|
||||
|
||||
@router.post("/server/{action}")
|
||||
async def server_action(action: str, _: dict = Depends(require_auth)):
|
||||
"""Control the ComfyUI service: start | stop | restart | install | uninstall"""
|
||||
config = get_config()
|
||||
if config is None:
|
||||
raise HTTPException(503, "Config not available")
|
||||
|
||||
valid_actions = {"start", "stop", "restart", "install", "uninstall"}
|
||||
if action not in valid_actions:
|
||||
raise HTTPException(400, f"Invalid action '{action}'")
|
||||
|
||||
from commands.server import _nssm, _install_service
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
if action == "install":
|
||||
ok, msg = await _install_service(config)
|
||||
if not ok:
|
||||
raise HTTPException(500, msg)
|
||||
elif action == "uninstall":
|
||||
await _nssm("stop", config.comfy_service_name)
|
||||
await _nssm("remove", config.comfy_service_name, "confirm")
|
||||
elif action == "start":
|
||||
await _nssm("start", config.comfy_service_name)
|
||||
elif action == "stop":
|
||||
await _nssm("stop", config.comfy_service_name)
|
||||
elif action == "restart":
|
||||
await _nssm("restart", config.comfy_service_name)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, str(exc))
|
||||
|
||||
return {"ok": True, "action": action}
|
||||
|
||||
|
||||
@router.get("/logs/tail")
|
||||
async def tail_logs(lines: int = 100, _: dict = Depends(require_auth)):
|
||||
"""Tail the ComfyUI log file."""
|
||||
config = get_config()
|
||||
if config is None or not config.comfy_log_dir:
|
||||
raise HTTPException(503, "Log directory not configured")
|
||||
|
||||
log_dir = Path(config.comfy_log_dir)
|
||||
log_file = log_dir / "comfyui.log"
|
||||
if not log_file.exists():
|
||||
return {"lines": []}
|
||||
|
||||
try:
|
||||
with open(log_file, "r", encoding="utf-8", errors="replace") as f:
|
||||
all_lines = f.readlines()
|
||||
tail = all_lines[-min(lines, len(all_lines)):]
|
||||
return {"lines": [ln.rstrip("\n") for ln in tail]}
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, str(exc))
|
||||
85
web/routers/share_router.py
Normal file
85
web/routers/share_router.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""GET /api/share/{token}; GET /api/share/{token}/file/{filename}"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
|
||||
from web.auth import require_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{token}")
|
||||
async def get_share(token: str, _: dict = Depends(require_auth)):
|
||||
"""Fetch share metadata and images. Any authenticated user may view a valid share link."""
|
||||
from generation_db import get_share_by_token, get_files
|
||||
gen = get_share_by_token(token)
|
||||
if gen is None:
|
||||
raise HTTPException(404, "Share not found or revoked")
|
||||
files = get_files(gen["prompt_id"])
|
||||
return {
|
||||
"prompt_id": gen["prompt_id"],
|
||||
"created_at": gen["created_at"],
|
||||
"overrides": gen["overrides"],
|
||||
"seed": gen["seed"],
|
||||
"images": [
|
||||
{
|
||||
"filename": f["filename"],
|
||||
"data": base64.b64encode(f["data"]).decode() if not f["mime_type"].startswith("video/") else None,
|
||||
"mime_type": f["mime_type"],
|
||||
}
|
||||
for f in files
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{token}/file/{filename}")
|
||||
async def get_share_file(
|
||||
token: str,
|
||||
filename: str,
|
||||
request: Request,
|
||||
_: dict = Depends(require_auth),
|
||||
):
|
||||
"""Stream a single output file via share token, with HTTP range support for video seeking."""
|
||||
from generation_db import get_share_by_token, get_files
|
||||
gen = get_share_by_token(token)
|
||||
if gen is None:
|
||||
raise HTTPException(404, "Share not found or revoked")
|
||||
files = get_files(gen["prompt_id"])
|
||||
matched = next((f for f in files if f["filename"] == filename), None)
|
||||
if matched is None:
|
||||
raise HTTPException(404, f"File {filename!r} not found")
|
||||
|
||||
data: bytes = matched["data"]
|
||||
mime: str = matched["mime_type"]
|
||||
total = len(data)
|
||||
|
||||
range_header = request.headers.get("range")
|
||||
if range_header:
|
||||
range_val = range_header.replace("bytes=", "")
|
||||
start_str, _, end_str = range_val.partition("-")
|
||||
start = int(start_str) if start_str else 0
|
||||
end = int(end_str) if end_str else total - 1
|
||||
end = min(end, total - 1)
|
||||
chunk = data[start : end + 1]
|
||||
return Response(
|
||||
content=chunk,
|
||||
status_code=206,
|
||||
media_type=mime,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{total}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(len(chunk)),
|
||||
},
|
||||
)
|
||||
|
||||
return Response(
|
||||
content=data,
|
||||
media_type=mime,
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(total),
|
||||
},
|
||||
)
|
||||
53
web/routers/state_router.py
Normal file
53
web/routers/state_router.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""GET/PUT /api/state; DELETE /api/state/{key}"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_config, get_user_registry
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _get_user_sm(user: dict):
|
||||
"""Return the per-user WorkflowStateManager, raising 503 if unavailable."""
|
||||
registry = get_user_registry()
|
||||
if registry is None:
|
||||
raise HTTPException(503, "State manager not available")
|
||||
return registry.get_state_manager(user["sub"])
|
||||
|
||||
|
||||
@router.get("/state")
|
||||
async def get_state(user: dict = Depends(require_auth)):
|
||||
"""Return all current overrides for the authenticated user."""
|
||||
sm = _get_user_sm(user)
|
||||
return sm.get_overrides()
|
||||
|
||||
|
||||
@router.put("/state")
|
||||
async def put_state(body: Dict[str, Any], user: dict = Depends(require_auth)):
|
||||
"""Merge override values. Pass ``null`` as a value to delete a key."""
|
||||
sm = _get_user_sm(user)
|
||||
for key, value in body.items():
|
||||
if value is None:
|
||||
sm.delete_override(key)
|
||||
else:
|
||||
sm.set_override(key, value)
|
||||
return sm.get_overrides()
|
||||
|
||||
|
||||
@router.delete("/state/{key}")
|
||||
async def delete_state_key(key: str, user: dict = Depends(require_auth)):
|
||||
"""Remove a single override key, and clean up any associated slot file."""
|
||||
sm = _get_user_sm(user)
|
||||
sm.delete_override(key)
|
||||
|
||||
config = get_config()
|
||||
if config:
|
||||
from input_image_db import deactivate_image_slot
|
||||
user_label: str = user["sub"]
|
||||
deactivate_image_slot(f"{user_label}_{key}", config.comfy_input_path)
|
||||
|
||||
return {"ok": True, "key": key}
|
||||
74
web/routers/status_router.py
Normal file
74
web/routers/status_router.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""GET /api/status — polling fallback for clients that can't use WebSocket"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_bot, get_comfy, get_config
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_status(_: dict = Depends(require_auth)):
|
||||
"""Return a full status snapshot."""
|
||||
bot = get_bot()
|
||||
comfy = get_comfy()
|
||||
config = get_config()
|
||||
|
||||
snap: dict = {}
|
||||
|
||||
if bot is not None:
|
||||
import datetime as _dt
|
||||
lat = bot.latency
|
||||
lat_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0
|
||||
start = getattr(bot, "start_time", None)
|
||||
uptime = ""
|
||||
if start:
|
||||
delta = _dt.datetime.now(_dt.timezone.utc) - start
|
||||
total = int(delta.total_seconds())
|
||||
h, rem = divmod(total, 3600)
|
||||
m, s = divmod(rem, 60)
|
||||
uptime = f"{h}h {m}m {s}s" if h else (f"{m}m {s}s" if m else f"{s}s")
|
||||
snap["bot"] = {"latency_ms": lat_ms, "uptime": uptime}
|
||||
|
||||
if comfy is not None:
|
||||
q_task = asyncio.create_task(comfy.get_comfy_queue())
|
||||
conn_task = asyncio.create_task(comfy.check_connection())
|
||||
q, reachable = await asyncio.gather(q_task, conn_task)
|
||||
|
||||
pending = len(q.get("queue_pending", [])) if q else 0
|
||||
running = len(q.get("queue_running", [])) if q else 0
|
||||
wm = getattr(comfy, "workflow_manager", None)
|
||||
wf_loaded = wm is not None and wm.get_workflow_template() is not None
|
||||
|
||||
snap["comfy"] = {
|
||||
"server": comfy.server_address,
|
||||
"reachable": reachable,
|
||||
"queue_pending": pending,
|
||||
"queue_running": running,
|
||||
"workflow_loaded": wf_loaded,
|
||||
"last_seed": comfy.last_seed,
|
||||
"total_generated": comfy.total_generated,
|
||||
}
|
||||
snap["overrides"] = comfy.state_manager.get_overrides()
|
||||
|
||||
if config is not None:
|
||||
from commands.server import get_service_state
|
||||
service_state = await get_service_state(config.comfy_service_name)
|
||||
snap["service"] = {"state": service_state}
|
||||
|
||||
try:
|
||||
from media_uploader import get_stats as us_fn, is_running as ur_fn
|
||||
us = us_fn()
|
||||
snap["upload"] = {
|
||||
"configured": bool(config.media_upload_user),
|
||||
"running": ur_fn(),
|
||||
"total_ok": us.total_ok,
|
||||
"total_fail": us.total_fail,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return snap
|
||||
175
web/routers/workflow_router.py
Normal file
175
web/routers/workflow_router.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
GET /api/workflow — current workflow info
|
||||
GET /api/workflow/inputs — dynamic NodeInput list
|
||||
GET /api/workflow/files — list files in workflows/
|
||||
POST /api/workflow/upload — upload a workflow JSON
|
||||
POST /api/workflow/load — load a workflow from workflows/
|
||||
GET /api/workflow/models?type=checkpoints|loras — available models (60s TTL cache)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
|
||||
from web.auth import require_auth
|
||||
from web.deps import get_comfy, get_config, get_inspector, get_user_registry
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
_WORKFLOWS_DIR = _PROJECT_ROOT / "workflows"
|
||||
|
||||
# Simple in-memory TTL cache for models
|
||||
_models_cache: dict = {}
|
||||
_MODELS_TTL = 60.0
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_workflow(user: dict = Depends(require_auth)):
|
||||
"""Return basic info about the currently loaded workflow."""
|
||||
user_label: str = user["sub"]
|
||||
registry = get_user_registry()
|
||||
|
||||
if registry:
|
||||
template = registry.get_workflow_template(user_label)
|
||||
last_wf = registry.get_state_manager(user_label).get_last_workflow_file()
|
||||
else:
|
||||
# Fallback to global state when registry is unavailable
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "Workflow manager not available")
|
||||
template = comfy.workflow_manager.get_workflow_template()
|
||||
last_wf = comfy.state_manager.get_last_workflow_file()
|
||||
|
||||
return {
|
||||
"loaded": template is not None,
|
||||
"node_count": len(template) if template else 0,
|
||||
"last_workflow_file": last_wf,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/inputs")
|
||||
async def get_workflow_inputs(user: dict = Depends(require_auth)):
|
||||
"""Return dynamic NodeInput list (common + advanced) for the current workflow."""
|
||||
user_label: str = user["sub"]
|
||||
inspector = get_inspector()
|
||||
if inspector is None:
|
||||
raise HTTPException(503, "Workflow components not available")
|
||||
|
||||
registry = get_user_registry()
|
||||
if registry:
|
||||
template = registry.get_workflow_template(user_label)
|
||||
overrides = registry.get_state_manager(user_label).get_overrides()
|
||||
else:
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "Workflow components not available")
|
||||
template = comfy.workflow_manager.get_workflow_template()
|
||||
overrides = comfy.state_manager.get_overrides()
|
||||
|
||||
if template is None:
|
||||
return {"common": [], "advanced": []}
|
||||
|
||||
inputs = inspector.inspect(template)
|
||||
result = []
|
||||
for ni in inputs:
|
||||
val = overrides.get(ni.key, ni.current_value)
|
||||
result.append({
|
||||
"key": ni.key,
|
||||
"label": ni.label,
|
||||
"input_type": ni.input_type,
|
||||
"current_value": val,
|
||||
"node_class": ni.node_class,
|
||||
"node_title": ni.node_title,
|
||||
"is_common": ni.is_common,
|
||||
})
|
||||
common = [r for r in result if r["is_common"]]
|
||||
advanced = [r for r in result if not r["is_common"]]
|
||||
return {"common": common, "advanced": advanced}
|
||||
|
||||
|
||||
@router.get("/files")
|
||||
async def list_workflow_files(_: dict = Depends(require_auth)):
|
||||
"""List .json files in the workflows/ folder."""
|
||||
_WORKFLOWS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
files = sorted(p.name for p in _WORKFLOWS_DIR.glob("*.json"))
|
||||
return {"files": files}
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_workflow(
|
||||
file: UploadFile = File(...),
|
||||
_: dict = Depends(require_auth),
|
||||
):
|
||||
"""Upload a workflow JSON to the workflows/ folder."""
|
||||
_WORKFLOWS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
filename = file.filename or "workflow.json"
|
||||
if not filename.endswith(".json"):
|
||||
filename += ".json"
|
||||
data = await file.read()
|
||||
try:
|
||||
json.loads(data) # validate JSON
|
||||
except json.JSONDecodeError as exc:
|
||||
raise HTTPException(400, f"Invalid JSON: {exc}")
|
||||
dest = _WORKFLOWS_DIR / filename
|
||||
dest.write_bytes(data)
|
||||
return {"ok": True, "filename": filename}
|
||||
|
||||
|
||||
@router.post("/load")
|
||||
async def load_workflow(filename: str = Form(...), user: dict = Depends(require_auth)):
|
||||
"""Load a workflow from the workflows/ folder into the user's isolated state."""
|
||||
wf_path = _WORKFLOWS_DIR / filename
|
||||
if not wf_path.exists():
|
||||
raise HTTPException(404, f"Workflow file '{filename}' not found")
|
||||
try:
|
||||
with open(wf_path, "r", encoding="utf-8") as f:
|
||||
workflow = json.load(f)
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, str(exc))
|
||||
|
||||
registry = get_user_registry()
|
||||
if registry:
|
||||
registry.set_workflow(user["sub"], workflow, filename)
|
||||
else:
|
||||
# Fallback: update global state when registry unavailable
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "ComfyUI not available")
|
||||
comfy.workflow_manager.set_workflow_template(workflow)
|
||||
comfy.state_manager.clear_overrides()
|
||||
comfy.state_manager.set_last_workflow_file(filename)
|
||||
|
||||
inspector = get_inspector()
|
||||
node_count = len(workflow)
|
||||
inputs_count = len(inspector.inspect(workflow)) if inspector else 0
|
||||
return {
|
||||
"ok": True,
|
||||
"filename": filename,
|
||||
"node_count": node_count,
|
||||
"inputs_count": inputs_count,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def get_models(type: str = "checkpoints", _: dict = Depends(require_auth)):
|
||||
"""Return available model names from ComfyUI (60s TTL cache)."""
|
||||
global _models_cache
|
||||
now = time.time()
|
||||
cache_key = type
|
||||
cached = _models_cache.get(cache_key)
|
||||
if cached and (now - cached["ts"]) < _MODELS_TTL:
|
||||
return {"type": type, "models": cached["models"]}
|
||||
|
||||
comfy = get_comfy()
|
||||
if comfy is None:
|
||||
raise HTTPException(503, "ComfyUI not available")
|
||||
models = await comfy.get_models(type)
|
||||
_models_cache[cache_key] = {"models": models, "ts": now}
|
||||
return {"type": type, "models": models}
|
||||
61
web/routers/ws_router.py
Normal file
61
web/routers/ws_router.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""WebSocket /ws?token=<jwt> — real-time event stream"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from web.auth import verify_ws_token
|
||||
from web.ws_bus import get_bus
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket, token: str = ""):
|
||||
"""
|
||||
Authenticate via JWT query param or ttb_session cookie, then stream events from WSBus.
|
||||
|
||||
Events common to all users: status_snapshot, queue_update, node_executing, server_state
|
||||
Events private to submitter: generation_complete, generation_error
|
||||
"""
|
||||
payload = verify_ws_token(token)
|
||||
if payload is None:
|
||||
# Fallback: browsers send cookies automatically with WebSocket connections
|
||||
cookie_token = websocket.cookies.get("ttb_session", "")
|
||||
payload = verify_ws_token(cookie_token)
|
||||
if payload is None:
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
user_label: str = payload.get("sub", "anonymous")
|
||||
bus = get_bus()
|
||||
queue = bus.subscribe(user_label)
|
||||
await websocket.accept()
|
||||
logger.info("WS connected: user=%s", user_label)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Wait for an event from the bus
|
||||
try:
|
||||
frame = await asyncio.wait_for(queue.get(), timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
# Send a keepalive ping
|
||||
try:
|
||||
await websocket.send_text('{"type":"ping"}')
|
||||
except Exception:
|
||||
break
|
||||
continue
|
||||
|
||||
try:
|
||||
await websocket.send_text(frame)
|
||||
except Exception:
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
bus.unsubscribe(user_label, queue)
|
||||
logger.info("WS disconnected: user=%s", user_label)
|
||||
Reference in New Issue
Block a user