Files
comfy-discord-web/web/routers/inputs_router.py
Khoa (Revenovich) Tran Gia 6004b000a7 manual submit
2026-03-07 21:49:16 +07:00

256 lines
9.0 KiB
Python

"""Input-image API routes.

GET/POST /api/inputs; DELETE /api/inputs/{id}; GET /api/inputs/{id}/image,
/api/inputs/{id}/thumb, /api/inputs/{id}/mid; POST /api/inputs/{id}/activate
"""
from __future__ import annotations
import logging
import mimetypes
import re
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import Response
from web.auth import require_auth
from web.deps import get_config, get_user_registry
router = APIRouter()
logger = logging.getLogger(__name__)

# Accepted upload extensions (lower-case, leading dot). Frozen so the
# allow-list cannot be mutated at runtime.
_ALLOWED_EXTS = frozenset({".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"})
_MAX_UPLOAD_BYTES = 50 * 1024 * 1024  # 50 MB
# Slot keys become part of a filename in the ComfyUI input folder, so only a
# conservative character set is allowed. Anchored with \A...\Z rather than
# ^...$ because ``$`` also matches just before a trailing newline.
_SAFE_SLOT_RE = re.compile(r"\A[a-zA-Z0-9_\-]+\Z")
def _validate_slot_key(slot_key: str) -> None:
    """Raise HTTP 400 unless *slot_key* is a safe slot identifier.

    Callers embed the key in a filename (``<user_label>_<slot_key>``) written
    to the ComfyUI input folder. It is therefore restricted to a non-empty run
    of ASCII letters, digits, hyphens and underscores: no path separators,
    dots or whitespace.
    """
    # fullmatch forces the whole string to match. ``re.match`` with a
    # ``$``-anchored pattern would also accept a trailing "\n".
    if not _SAFE_SLOT_RE.fullmatch(slot_key):
        raise HTTPException(400, "slot_key may only contain letters, digits, hyphens, underscores")
@router.get("")
async def list_inputs(
    _: dict = Depends(require_auth),
    persons: list[str] = Query(default=[], alias="persons", description="Filter by person name/alias substring (repeatable)"),
):
    """List input images (Discord + web uploads), optionally filtered by person.

    Without any non-blank ``persons`` value, every stored image row is
    returned as a dict.

    With one or more ``persons`` values, each is looked up through the face DB
    (source type ``"input"``). The union of the matching images is returned,
    and each image carries a ``detected_persons`` list taken from the face DB.
    """
    queries = [p.strip() for p in persons if p.strip()]
    if not queries:
        from input_image_db import get_all_images
        return [dict(r) for r in get_all_images()]

    import face_db as face_db_mod
    from input_image_db import get_images_by_ids

    matched_ids: set[int] = set()
    for query in queries:
        matched_ids.update(face_db_mod.get_source_ids_for_person_query(query, "input"))
    if not matched_ids:
        return []

    # Copy into plain dicts, as the unfiltered branch does, so that
    # ``detected_persons`` can be attached even if the DB layer returns
    # immutable row objects.
    images = [dict(r) for r in get_images_by_ids(list(matched_ids))]
    if images:
        person_map = face_db_mod.get_persons_for_source_id_map(
            [img["id"] for img in images], "input"
        )
        for img in images:
            img["detected_persons"] = person_map.get(img["id"], [])
    return images
@router.post("")
async def upload_input(
    file: UploadFile = File(...),
    slot_key: Optional[str] = Form(default=None),
    user: dict = Depends(require_auth),
):
    """
    Upload an input image.

    Stores image bytes directly in SQLite. If *slot_key* is provided the
    image is immediately activated for that slot (writes to ComfyUI input
    folder and updates the user's state override).

    The physical slot file uses a namespaced key ``<user_label>_<slot_key>``
    so concurrent users each get their own active image file.

    After storing, the image is scanned for faces on a best-effort basis.
    Detected faces without a matched person are returned in ``pending_faces``.

    Raises:
        HTTPException(503): config is not available.
        HTTPException(400): *slot_key* is malformed, or the file is empty.
        HTTPException(415): the filename extension is not an allowed image type.
        HTTPException(413): the file exceeds 50 MB.
        HTTPException(409): the stored image could not be activated for the slot.
    """
    config = get_config()
    if config is None:
        raise HTTPException(503, "Config not available")
    if slot_key:
        _validate_slot_key(slot_key)

    filename = file.filename or "upload.png"
    ext = Path(filename).suffix.lower()
    # Reject on the cheap extension check before reading any of the body.
    if ext not in _ALLOWED_EXTS:
        raise HTTPException(415, f"Unsupported file type '{ext}'. Allowed: {sorted(_ALLOWED_EXTS)}")
    # Read at most one byte past the limit so an oversized upload is rejected
    # without pulling the whole thing into memory.
    data = await file.read(_MAX_UPLOAD_BYTES + 1)
    if len(data) > _MAX_UPLOAD_BYTES:
        raise HTTPException(413, "File too large (max 50 MB)")
    if not data:
        raise HTTPException(400, "Uploaded file is empty")

    from input_image_db import upsert_image, activate_image_for_slot
    row_id = upsert_image(
        original_message_id=0,  # sentinel for web uploads
        bot_reply_id=0,
        channel_id=0,
        filename=filename,
        image_data=data,
    )

    activated_filename: str | None = None
    if slot_key:
        user_label: str = user["sub"]
        namespaced_key = f"{user_label}_{slot_key}"
        try:
            activated_filename = activate_image_for_slot(
                row_id, namespaced_key, config.comfy_input_path
            )
        except ValueError as exc:
            # Same status mapping as POST /{row_id}/activate.
            raise HTTPException(409, str(exc)) from exc
        registry = get_user_registry()
        if registry:
            registry.get_state_manager(user_label).set_override(slot_key, activated_filename)
        else:
            from web.deps import get_comfy
            comfy = get_comfy()
            if comfy:
                comfy.state_manager.set_override(slot_key, activated_filename)
            else:
                # The slot file is written, but nothing records the override.
                logger.warning(
                    "Upload row_id=%d activated as %s but no state manager is available; "
                    "slot override for %r not set",
                    row_id, activated_filename, slot_key,
                )

    # Face scan runs synchronously here (~1-2 s); unknown faces are returned
    # to the UI. Failures are logged and never fail the upload.
    pending_faces: list[dict] = []
    try:
        from face_service import get_face_service
        import face_db as _face_db
        _face_db.init_db()
        svc = get_face_service()
        if svc.available:
            results = await svc.scan_input_image(row_id, data)
            pending_faces = [
                {"detection_id": r.detection_id, "face_index": r.face_index, "bbox": r.bbox}
                for r in results
                if r.matched_person_id is None
            ]
    except Exception as exc:
        logger.warning("Face scan failed for upload row_id=%d: %s", row_id, exc)

    return {
        "id": row_id,
        "filename": filename,
        "slot_key": slot_key,
        "activated_filename": activated_filename,
        "pending_faces": pending_faces,
    }
@router.post("/{row_id}/activate")
async def activate_input(
    row_id: int,
    slot_key: str = Body(default="input_image", embed=True),
    user: dict = Depends(require_auth),
):
    """Write the stored image to the ComfyUI input folder and set the user's slot override.

    The file is written under the namespaced key ``<user_label>_<slot_key>``.
    The override is recorded on the user's state manager, or on the shared
    Comfy state manager when no user registry is configured.

    Raises:
        HTTPException(503): config or a state manager is not available.
        HTTPException(400): *slot_key* is malformed.
        HTTPException(404): no image with *row_id* exists.
        HTTPException(409): ``activate_image_for_slot`` rejected the image.
    """
    config = get_config()
    if config is None:
        raise HTTPException(503, "Config not available")
    # Reject a malformed key before touching the database.
    _validate_slot_key(slot_key)

    from input_image_db import get_image, activate_image_for_slot
    if get_image(row_id) is None:
        raise HTTPException(404, "Image not found")

    user_label: str = user["sub"]

    # Resolve where the override will be stored *before* writing the slot
    # file, so a 503 never leaves an orphaned file behind.
    registry = get_user_registry()
    if registry:
        state_manager = registry.get_state_manager(user_label)
    else:
        from web.deps import get_comfy
        comfy = get_comfy()
        if comfy is None:
            raise HTTPException(503, "State manager not available")
        state_manager = comfy.state_manager

    namespaced_key = f"{user_label}_{slot_key}"
    try:
        filename = activate_image_for_slot(row_id, namespaced_key, config.comfy_input_path)
    except ValueError as exc:
        raise HTTPException(409, str(exc)) from exc

    state_manager.set_override(slot_key, filename)
    return {"ok": True, "slot_key": slot_key, "filename": filename}
@router.delete("/{row_id}")
async def delete_input(row_id: int, _: dict = Depends(require_auth)):
    """Remove one stored input image record.

    Returns 404 if *row_id* does not exist. When config is available, the
    ComfyUI input path is forwarded to ``delete_image`` so that an active slot
    file for the image can be cleaned up as well.
    """
    from input_image_db import delete_image, get_image

    if get_image(row_id) is None:
        raise HTTPException(404, "Image not found")

    cfg = get_config()
    input_dir = None
    if cfg:
        input_dir = cfg.comfy_input_path
    delete_image(row_id, comfy_input_path=input_dir)
    return {"ok": True}
@router.get("/{row_id}/image")
async def get_input_image(row_id: int, _: dict = Depends(require_auth)):
    """Return the original image bytes stored in the database for *row_id*.

    The Content-Type is guessed from the stored filename's extension and
    falls back to ``application/octet-stream``. Returns 404 when the row is
    missing or has no stored bytes.
    """
    from input_image_db import get_image, get_image_data

    record = get_image(row_id)
    if record is None:
        raise HTTPException(404, "Image not found")

    payload = get_image_data(row_id)
    if payload is None:
        raise HTTPException(404, "Image data not available — re-upload to backfill")

    guessed_type = mimetypes.guess_type(record["filename"])[0]
    return Response(content=payload, media_type=guessed_type or "application/octet-stream")
def _pil_resize_response(data: bytes, filename: str, max_size: int, quality: int) -> Response:
    """Downscale image bytes with Pillow and return them as a cacheable JPEG Response.

    Processing steps:
      * The image is shrunk to fit within ``max_size`` x ``max_size``
        (aspect ratio kept, never enlarged).
      * It is rotated according to its EXIF Orientation tag.
      * If it has an alpha channel, it is flattened onto a white background.
      * It is re-encoded as JPEG at *quality*.

    *filename* is currently unused; it is kept for signature compatibility.

    Raises whatever Pillow raises for undecodable input. Callers catch this
    and fall back to serving the original bytes.
    """
    import io
    from PIL import Image as _PIL
    from PIL import ImageOps

    img = _PIL.open(io.BytesIO(data))
    if img.mode in ("P", "PA"):
        # Pillow resizes palette images with NEAREST regardless of the filter
        # requested. Expanding to RGBA first lets LANCZOS apply, and it keeps
        # palette transparency as a real alpha channel.
        img = img.convert("RGBA")
    # thumbnail() before exif_transpose() so JPEG draft-mode decoding can
    # still kick in. The bounding box is square, so a later 90° rotation
    # stays within max_size.
    img.thumbnail((max_size, max_size), _PIL.LANCZOS)
    # JPEG re-encoding drops EXIF. Bake the orientation into the pixels, or
    # phone photos show up rotated compared with the full-size image.
    img = ImageOps.exif_transpose(img)
    if img.mode in ("RGBA", "LA"):
        # A plain convert("RGB") discards alpha and exposes whatever colour
        # sits under transparent pixels (often black); composite on white.
        flattened = _PIL.new("RGB", img.size, (255, 255, 255))
        flattened.paste(img.convert("RGB"), mask=img.getchannel("A"))
        img = flattened
    else:
        img = img.convert("RGB")

    buf = io.BytesIO()
    img.save(buf, "JPEG", quality=quality, optimize=True)
    return Response(
        content=buf.getvalue(),
        media_type="image/jpeg",
        headers={"Cache-Control": "public, max-age=86400"},
    )
@router.get("/{row_id}/thumb")
async def get_input_thumb(row_id: int, _: dict = Depends(require_auth)):
    """Serve a small compressed thumbnail (max 200 px, JPEG 65 %) for fast previews.

    If Pillow cannot process the stored bytes, the original image is served
    unchanged, with a Content-Type guessed from the filename.
    """
    import asyncio
    from input_image_db import get_image, get_image_data

    row = get_image(row_id)
    if row is None:
        raise HTTPException(404, "Image not found")
    data = get_image_data(row_id)
    if data is None:
        raise HTTPException(404, "Image data not available — re-upload to backfill")
    try:
        # Decoding and resampling are CPU-bound; run them in a worker thread
        # so this async handler does not stall the event loop.
        return await asyncio.to_thread(
            _pil_resize_response, data, row["filename"], max_size=200, quality=65
        )
    except Exception:
        # Best effort: fall back to the original bytes, but leave a trace.
        logger.debug(
            "Thumbnail generation failed for input row_id=%d; serving original",
            row_id, exc_info=True,
        )
        mime, _encoding = mimetypes.guess_type(row["filename"])
        return Response(content=data, media_type=mime or "application/octet-stream")
@router.get("/{row_id}/mid")
async def get_input_mid(row_id: int, _: dict = Depends(require_auth)):
    """Serve a medium compressed image (max 800 px, JPEG 80 %) for progressive loading.

    If Pillow cannot process the stored bytes, the original image is served
    unchanged, with a Content-Type guessed from the filename.
    """
    import asyncio
    from input_image_db import get_image, get_image_data

    row = get_image(row_id)
    if row is None:
        raise HTTPException(404, "Image not found")
    data = get_image_data(row_id)
    if data is None:
        raise HTTPException(404, "Image data not available — re-upload to backfill")
    try:
        # Decoding and resampling are CPU-bound; run them in a worker thread
        # so this async handler does not stall the event loop.
        return await asyncio.to_thread(
            _pil_resize_response, data, row["filename"], max_size=800, quality=80
        )
    except Exception:
        # Best effort: fall back to the original bytes, but leave a trace.
        logger.debug(
            "Mid-size image generation failed for input row_id=%d; serving original",
            row_id, exc_info=True,
        )
        mime, _encoding = mimetypes.guess_type(row["filename"])
        return Response(content=data, media_type=mime or "application/octet-stream")