256 lines
9.0 KiB
Python
256 lines
9.0 KiB
Python
"""Input image API: GET/POST /api/inputs; DELETE /api/inputs/{id}; GET /api/inputs/{id}/image, /thumb, /mid; POST /api/inputs/{id}/activate"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import mimetypes
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, Query, UploadFile
|
|
from fastapi.responses import Response
|
|
|
|
from web.auth import require_auth
|
|
from web.deps import get_config, get_user_registry
|
|
|
|
logger = logging.getLogger(__name__)

# Routes below are registered relative to this router; presumably it is
# mounted at /api/inputs (per the module docstring) — confirm in the app setup.
router = APIRouter()
|
|
|
|
_ALLOWED_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
|
|
_MAX_UPLOAD_BYTES = 50 * 1024 * 1024 # 50 MB
|
|
_SAFE_SLOT_RE = re.compile(r'^[a-zA-Z0-9_\-]+$')
|
|
|
|
|
|
def _validate_slot_key(slot_key: str) -> None:
    """Reject slot keys that are unsafe to embed in the namespaced slot-file key.

    Args:
        slot_key: Caller-supplied slot identifier.

    Raises:
        HTTPException: 400 unless the *whole* of ``slot_key`` matches
            ``_SAFE_SLOT_RE`` (non-empty; ASCII letters, digits, ``_``, ``-``).
    """
    # fullmatch rather than match: the pattern's trailing "$" also matches
    # just before a final "\n", so with match() a key like "slot\n" would
    # be accepted and leak a newline into the slot filename.
    if _SAFE_SLOT_RE.fullmatch(slot_key) is None:
        raise HTTPException(400, "slot_key may only contain letters, digits, hyphens, underscores")
|
|
|
|
|
|
@router.get("")
async def list_inputs(
    _: dict = Depends(require_auth),
    persons: list[str] = Query(default=[], alias="persons", description="Filter by person name/alias substring (repeatable)"),
):
    """List all input images (Discord + web uploads). Optionally filter by persons.

    Args:
        persons: Zero or more person name/alias substrings (repeatable query
            parameter). Blank entries are ignored. When at least one
            non-blank value is given, only images matching ANY of them are
            returned, each annotated with a ``detected_persons`` list.

    Returns:
        A list of plain dicts, one per image row.
    """
    active_persons = [p.strip() for p in persons if p.strip()]

    if not active_persons:
        from input_image_db import get_all_images

        return [dict(r) for r in get_all_images()]

    import face_db as face_db_mod
    from input_image_db import get_images_by_ids

    # Union of matches across all requested persons.
    all_ids: set[int] = set()
    for p in active_persons:
        all_ids.update(face_db_mod.get_source_ids_for_person_query(p, "input"))

    if not all_ids:
        return []

    # Copy into plain dicts (as the unfiltered branch does) so that the
    # ``detected_persons`` key can be attached even when the DB layer hands
    # back read-only row objects such as sqlite3.Row.
    images = [dict(r) for r in get_images_by_ids(list(all_ids))]
    if images:
        person_map = face_db_mod.get_persons_for_source_id_map(
            [img["id"] for img in images], "input"
        )
        for img in images:
            img["detected_persons"] = person_map.get(img["id"], [])
    return images
|
|
|
|
|
|
async def _scan_unknown_faces(row_id: int, data: bytes) -> list[dict]:
    """Run face detection on a stored upload and return faces with no known person.

    Best-effort: any failure is logged as a warning and yields ``[]`` so the
    upload itself still succeeds.
    """
    try:
        from face_service import get_face_service
        import face_db as _face_db

        _face_db.init_db()
        svc = get_face_service()
        if not svc.available:
            return []
        results = await svc.scan_input_image(row_id, data)
        return [
            {"detection_id": r.detection_id, "face_index": r.face_index, "bbox": r.bbox}
            for r in results
            if r.matched_person_id is None
        ]
    except Exception as exc:
        logger.warning("Face scan failed for upload row_id=%d: %s", row_id, exc)
        return []


@router.post("")
async def upload_input(
    file: UploadFile = File(...),
    slot_key: Optional[str] = Form(default=None),
    user: dict = Depends(require_auth),
):
    """
    Upload an input image.

    Stores image bytes directly in SQLite. If *slot_key* is provided the
    image is immediately activated for that slot (writes to ComfyUI input
    folder and updates the user's state override).

    The physical slot file uses a namespaced key ``<user_label>_<slot_key>``
    so concurrent users each get their own active image file.

    Returns:
        dict with the new row ``id``, the stored ``filename``, ``slot_key``,
        ``activated_filename`` (None unless a slot was activated) and
        ``pending_faces`` (detected faces not matched to a known person).

    Raises:
        HTTPException: 503 if config is unavailable; 400 for an invalid
            slot_key or an empty file; 415 for a disallowed extension; 413 if
            the file exceeds the upload limit; 409 if activating the slot is
            rejected (the image itself stays stored).
    """
    config = get_config()
    if config is None:
        raise HTTPException(503, "Config not available")

    if slot_key:
        _validate_slot_key(slot_key)

    # Keep only the final path component of the client-supplied name; it is
    # stored in the DB and echoed back in the response.
    filename = Path(file.filename or "").name or "upload.png"
    ext = Path(filename).suffix.lower()
    if ext not in _ALLOWED_EXTS:
        raise HTTPException(415, f"Unsupported file type '{ext}'. Allowed: {sorted(_ALLOWED_EXTS)}")

    # Read at most one byte past the limit so an oversized upload is rejected
    # without buffering the whole body in memory.
    data = await file.read(_MAX_UPLOAD_BYTES + 1)
    if len(data) > _MAX_UPLOAD_BYTES:
        raise HTTPException(413, "File too large (max 50 MB)")
    if not data:
        raise HTTPException(400, "Uploaded file is empty")

    from input_image_db import upsert_image, activate_image_for_slot

    row_id = upsert_image(
        original_message_id=0,  # sentinel for web uploads
        bot_reply_id=0,
        channel_id=0,
        filename=filename,
        image_data=data,
    )

    activated_filename: str | None = None
    if slot_key:
        user_label: str = user["sub"]
        namespaced_key = f"{user_label}_{slot_key}"
        try:
            activated_filename = activate_image_for_slot(
                row_id, namespaced_key, config.comfy_input_path
            )
        except ValueError as exc:
            # Same status the /activate endpoint uses for a rejected activation.
            raise HTTPException(409, str(exc)) from exc

        registry = get_user_registry()
        if registry:
            registry.get_state_manager(user_label).set_override(slot_key, activated_filename)
        else:
            from web.deps import get_comfy

            comfy = get_comfy()
            if comfy:
                comfy.state_manager.set_override(slot_key, activated_filename)

    # Face scan runs inline; unknown faces are returned to the UI.
    pending_faces = await _scan_unknown_faces(row_id, data)

    return {
        "id": row_id,
        "filename": filename,
        "slot_key": slot_key,
        "activated_filename": activated_filename,
        "pending_faces": pending_faces,
    }
|
|
|
|
|
|
@router.post("/{row_id}/activate")
async def activate_input(
    row_id: int,
    slot_key: str = Body(default="input_image", embed=True),
    user: dict = Depends(require_auth),
):
    """Write the stored image to the ComfyUI input folder and set the user's slot override.

    Args:
        row_id: Primary key of the stored input image.
        slot_key: Slot to activate (JSON body field, default ``"input_image"``).

    Returns:
        ``{"ok": True, "slot_key": ..., "filename": ...}`` where ``filename``
        is the value returned by ``activate_image_for_slot``.

    Raises:
        HTTPException: 400 for an invalid slot_key; 503 if config or the state
            manager is unavailable; 404 if the image does not exist; 409 if
            ``activate_image_for_slot`` rejects the activation.
    """
    # Reject bad client input before touching config, the DB or the disk.
    _validate_slot_key(slot_key)

    config = get_config()
    if config is None:
        raise HTTPException(503, "Config not available")

    from input_image_db import get_image, activate_image_for_slot

    if get_image(row_id) is None:
        raise HTTPException(404, "Image not found")

    user_label: str = user["sub"]

    # Resolve the state manager first, so an unavailable one fails the request
    # before the slot file is written rather than after.
    registry = get_user_registry()
    if registry:
        state_manager = registry.get_state_manager(user_label)
    else:
        from web.deps import get_comfy

        comfy = get_comfy()
        if comfy is None:
            raise HTTPException(503, "State manager not available")
        state_manager = comfy.state_manager

    # Per-user slot file so concurrent users don't overwrite each other.
    namespaced_key = f"{user_label}_{slot_key}"
    try:
        filename = activate_image_for_slot(row_id, namespaced_key, config.comfy_input_path)
    except ValueError as exc:
        raise HTTPException(409, str(exc)) from exc

    state_manager.set_override(slot_key, filename)

    return {"ok": True, "slot_key": slot_key, "filename": filename}
|
|
|
|
|
|
@router.delete("/{row_id}")
async def delete_input(row_id: int, _: dict = Depends(require_auth)):
    """Remove a stored input image record.

    ``delete_image`` is also handed the ComfyUI input folder (when config is
    available) so it can clean up the image's active slot file, if any.

    Raises:
        HTTPException: 404 if no image with *row_id* exists.
    """
    from input_image_db import delete_image, get_image

    if get_image(row_id) is None:
        raise HTTPException(404, "Image not found")

    cfg = get_config()
    input_dir = None
    if cfg:
        input_dir = cfg.comfy_input_path
    delete_image(row_id, comfy_input_path=input_dir)
    return {"ok": True}
|
|
|
|
|
|
@router.get("/{row_id}/image")
async def get_input_image(row_id: int, _: dict = Depends(require_auth)):
    """Return the original stored bytes of an input image.

    The media type is guessed from the stored filename, falling back to
    ``application/octet-stream``.

    Raises:
        HTTPException: 404 if the row does not exist or has no stored bytes.
    """
    from input_image_db import get_image, get_image_data

    record = get_image(row_id)
    if record is None:
        raise HTTPException(404, "Image not found")

    payload = get_image_data(row_id)
    if payload is None:
        raise HTTPException(404, "Image data not available — re-upload to backfill")

    media_type = mimetypes.guess_type(record["filename"])[0] or "application/octet-stream"
    return Response(content=payload, media_type=media_type)
|
|
|
|
|
|
def _pil_resize_response(data: bytes, filename: str, max_size: int, quality: int) -> Response:
    """Downscale image bytes with Pillow and return them as a JPEG ``Response``.

    The image is shrunk to fit within ``max_size`` x ``max_size`` (Pillow's
    ``thumbnail`` keeps the aspect ratio). It is then rotated according to its
    EXIF Orientation tag, flattened onto white if it has transparency, and
    re-encoded as JPEG.

    Args:
        data: Encoded source image bytes.
        filename: Original filename (not used; kept for call compatibility).
        max_size: Maximum output width and height in pixels.
        quality: JPEG quality passed to Pillow's encoder.

    Returns:
        A ``Response`` with ``image/jpeg`` content and a one-day public
        ``Cache-Control`` header.

    Raises:
        Any Pillow/IO exception for input that cannot be decoded or encoded.
    """
    import io

    from PIL import Image as _PIL
    from PIL import ImageOps

    with _PIL.open(io.BytesIO(data)) as src:
        # Shrink first so later steps operate on the small image.
        src.thumbnail((max_size, max_size), _PIL.LANCZOS)
        # Honour the EXIF Orientation tag (otherwise phone photos taken in
        # portrait render sideways). Returns a new, fully loaded image.
        img = ImageOps.exif_transpose(src)

    if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
        # JPEG has no alpha; convert("RGB") would just drop it and expose the
        # (often black) colour under transparent pixels. Composite on white.
        rgba = img.convert("RGBA")
        rgb = _PIL.new("RGB", rgba.size, (255, 255, 255))
        rgb.paste(rgba, mask=rgba.getchannel("A"))
    else:
        rgb = img.convert("RGB")

    buf = io.BytesIO()
    rgb.save(buf, "JPEG", quality=quality, optimize=True)
    return Response(
        content=buf.getvalue(),
        media_type="image/jpeg",
        headers={"Cache-Control": "public, max-age=86400"},
    )
|
|
|
|
|
|
@router.get("/{row_id}/thumb")
async def get_input_thumb(row_id: int, _: dict = Depends(require_auth)):
    """Serve a small compressed thumbnail (max 200 px, JPEG 65 %) for fast previews.

    If resizing fails, the original bytes are served instead, with the media
    type guessed from the stored filename.

    Raises:
        HTTPException: 404 if the row does not exist or has no stored bytes.
    """
    import asyncio

    from input_image_db import get_image, get_image_data

    row = get_image(row_id)
    if row is None:
        raise HTTPException(404, "Image not found")

    data = get_image_data(row_id)
    if data is None:
        raise HTTPException(404, "Image data not available — re-upload to backfill")

    try:
        # Decoding/resizing is CPU-bound: run it in a worker thread so the
        # event loop keeps serving other requests meanwhile.
        return await asyncio.to_thread(
            _pil_resize_response, data, row["filename"], max_size=200, quality=65
        )
    except Exception as exc:
        # Best-effort fallback to the original bytes, but record why.
        logger.warning("Thumbnail resize failed for input row_id=%d: %s", row_id, exc)
        mime, _unused = mimetypes.guess_type(row["filename"])
        return Response(content=data, media_type=mime or "application/octet-stream")
|
|
|
|
|
|
@router.get("/{row_id}/mid")
async def get_input_mid(row_id: int, _: dict = Depends(require_auth)):
    """Serve a medium compressed image (max 800 px, JPEG 80 %) for progressive loading.

    If resizing fails, the original bytes are served instead, with the media
    type guessed from the stored filename.

    Raises:
        HTTPException: 404 if the row does not exist or has no stored bytes.
    """
    import asyncio

    from input_image_db import get_image, get_image_data

    row = get_image(row_id)
    if row is None:
        raise HTTPException(404, "Image not found")

    data = get_image_data(row_id)
    if data is None:
        raise HTTPException(404, "Image data not available — re-upload to backfill")

    try:
        # Decoding/resizing is CPU-bound: run it in a worker thread so the
        # event loop keeps serving other requests meanwhile.
        return await asyncio.to_thread(
            _pil_resize_response, data, row["filename"], max_size=800, quality=80
        )
    except Exception as exc:
        # Best-effort fallback to the original bytes, but record why.
        logger.warning("Mid-size resize failed for input row_id=%d: %s", row_id, exc)
        mime, _unused = mimetypes.guess_type(row["filename"])
        return Response(content=data, media_type=mime or "application/octet-stream")
|