-
Current workflow
+
+
+ Current workflow
+
{wf?.loaded ? (
<>
-
{wf.last_workflow_file ?? '(loaded from state)'}
-
{wf.node_count} node(s) detected
+
{wf.last_workflow_file ?? '(loaded from state)'}
+
{wf.node_count} node(s) detected
>
) : (
No workflow loaded
@@ -76,12 +82,12 @@ export default function WorkflowPage() {
{message && (
-
{message}
+
{message}
)}
{/* Upload */}
-
{/* Available files */}
{(filesData?.files ?? []).length > 0 && (
-
Available workflows
-
- {(filesData?.files ?? []).map(f => (
- -
-
+
+ Available workflows
+
+
+ {(filesData?.files ?? []).map((f, idx) => (
+
+
{f}
- {wf?.last_workflow_file === f && (active)}
+ {wf?.last_workflow_file === f && (
+ (active)
+ )}
-
+
))}
-
+
)}
{/* Discovered inputs summary */}
{inputs && allInputs.length > 0 && (
-
Discovered inputs ({allInputs.length})
-
-
-
-
- | Key |
- Label |
- Type |
- Common |
-
-
-
- {allInputs.map(inp => (
-
- | {inp.key} |
- {inp.label} |
- {inp.input_type} |
- {inp.is_common ? '✓' : ''} |
+
+ Discovered inputs ({allInputs.length})
+
+
+
+
+
+ | Key |
+ Label |
+ Type |
+ Common |
- ))}
-
-
+
+
+ {allInputs.map((inp, idx) => (
+
+ | {inp.key} |
+ {inp.label} |
+ {inp.input_type} |
+ {inp.is_common ? '✓' : ''} |
+
+ ))}
+
+
)}
diff --git a/frontend/tailwind.config.js b/frontend/tailwind.config.js
index 4d76c5e..645242e 100644
--- a/frontend/tailwind.config.js
+++ b/frontend/tailwind.config.js
@@ -3,7 +3,22 @@ export default {
content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
darkMode: 'class',
theme: {
- extend: {},
+ extend: {
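+ // surfaced in markup as the `animate-fade-in` / `animate-slide-up` utility classes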
+ animation: {
+ 'fade-in': 'fade-in 0.25s ease-out',
+ 'slide-up': 'slide-up 0.3s ease-out',
+ },
+ keyframes: {
+ 'fade-in': {
+ from: { opacity: '0', transform: 'translateY(8px)' },
+ to: { opacity: '1', transform: 'translateY(0)' },
+ },
+ 'slide-up': {
+ from: { opacity: '0', transform: 'translateY(16px)' },
+ to: { opacity: '1', transform: 'translateY(0)' },
+ },
+ },
+ },
},
plugins: [],
}
diff --git a/generation_db.py b/generation_db.py
index dd21b25..d1af8fa 100644
--- a/generation_db.py
+++ b/generation_db.py
@@ -48,11 +48,30 @@ CREATE TABLE IF NOT EXISTS generation_shares (
share_token TEXT UNIQUE NOT NULL,
prompt_id TEXT NOT NULL,
owner_label TEXT NOT NULL,
- created_at TEXT NOT NULL
+ created_at TEXT NOT NULL,
+ is_public INTEGER NOT NULL DEFAULT 0,
+ expires_at TEXT,
+ max_views INTEGER,
+ view_count INTEGER NOT NULL DEFAULT 0
);
"""
+def _migrate_shares_table(conn: sqlite3.Connection) -> None:
+ migrations = [
+ "ALTER TABLE generation_shares ADD COLUMN is_public INTEGER NOT NULL DEFAULT 0",
+ "ALTER TABLE generation_shares ADD COLUMN expires_at TEXT",
+ "ALTER TABLE generation_shares ADD COLUMN max_views INTEGER",
+ "ALTER TABLE generation_shares ADD COLUMN view_count INTEGER NOT NULL DEFAULT 0",
+ ]
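+ # SQLite has no ALTER TABLE ... ADD COLUMN IF NOT EXISTS, so each ALTER is attempted
+ # and the OperationalError raised when the column already exists is treated as success.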
+ for sql in migrations:
+ try:
+ conn.execute(sql)
+ except sqlite3.OperationalError:
+ pass # column already exists
+ conn.commit()
+
+
def _connect(db_path: Path | None = None) -> sqlite3.Connection:
path = db_path if db_path is not None else _DB_PATH
conn = sqlite3.connect(str(path), check_same_thread=False)
@@ -84,6 +103,7 @@ def init_db(db_path: Path = _DB_PATH) -> None:
with _connect(db_path) as conn:
conn.executescript(_SCHEMA)
conn.commit()
+ _migrate_shares_table(conn)
def record_generation(
@@ -109,11 +129,11 @@ def record_generation(
return cur.lastrowid # type: ignore[return-value]
-def record_file(generation_id: int, filename: str, file_data: bytes) -> None:
- """Insert a file BLOB row, auto-detecting MIME type from magic bytes."""
+def record_file(generation_id: int, filename: str, file_data: bytes) -> int:
+ """Insert a file BLOB row, auto-detecting MIME type from magic bytes. Returns the row id."""
mime_type = _detect_mime(file_data)
with _connect() as conn:
- conn.execute(
+ cur = conn.execute(
"""
INSERT INTO generation_files (generation_id, filename, file_data, mime_type)
VALUES (?, ?, ?, ?)
@@ -121,6 +141,7 @@ def record_file(generation_id: int, filename: str, file_data: bytes) -> None:
(generation_id, filename, file_data, mime_type),
)
conn.commit()
+ return cur.lastrowid # type: ignore[return-value]
def _rows_to_history(conn: sqlite3.Connection, rows) -> list[dict]:
@@ -151,7 +172,9 @@ def get_history(limit: int = 50) -> list[dict]:
rows = conn.execute(
"""
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
- s.share_token
+ s.share_token, s.is_public AS share_is_public,
+ s.expires_at AS share_expires_at, s.max_views AS share_max_views,
+ s.view_count AS share_view_count
FROM generation_history h
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label
ORDER BY h.id DESC LIMIT ?
@@ -167,7 +190,9 @@ def get_history_for_user(user_label: str, limit: int = 50) -> list[dict]:
rows = conn.execute(
"""
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
- s.share_token
+ s.share_token, s.is_public AS share_is_public,
+ s.expires_at AS share_expires_at, s.max_views AS share_max_views,
+ s.view_count AS share_view_count
FROM generation_history h
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ?
WHERE h.user_label = ?
@@ -214,7 +239,9 @@ def search_history_for_user(user_label: str, query: str, limit: int = 50) -> lis
rows = conn.execute(
"""
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
- s.share_token
+ s.share_token, s.is_public AS share_is_public,
+ s.expires_at AS share_expires_at, s.max_views AS share_max_views,
+ s.view_count AS share_view_count
FROM generation_history h
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ?
WHERE h.user_label = ? AND LOWER(h.overrides) LIKE LOWER(?)
@@ -231,7 +258,9 @@ def search_history(query: str, limit: int = 50) -> list[dict]:
rows = conn.execute(
"""
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
- s.share_token
+ s.share_token, s.is_public AS share_is_public,
+ s.expires_at AS share_expires_at, s.max_views AS share_max_views,
+ s.view_count AS share_view_count
FROM generation_history h
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label
WHERE LOWER(h.overrides) LIKE LOWER(?)
@@ -242,43 +271,111 @@ def search_history(query: str, limit: int = 50) -> list[dict]:
return _rows_to_history(conn, rows)
-def create_share(prompt_id: str, owner_label: str) -> str:
- """Create a share token for *prompt_id*. Idempotent — returns the same token if one exists."""
+def _is_share_expired(share_row: dict) -> bool:
+ """Return True if the share has passed its time or view limits."""
+ if share_row["expires_at"] and datetime.fromisoformat(share_row["expires_at"]) <= datetime.now(timezone.utc):
+ return True
+ if share_row["max_views"] is not None and share_row["view_count"] >= share_row["max_views"]:
+ return True
+ return False
+
+
+def _is_share_streaming_expired(share_row: dict) -> bool:
+ """Expiry check for file-streaming calls (no view_count increment).
+ Uses strict-greater-than so that files remain accessible within the
+ same page view that just consumed the last allowed view."""
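+ # e.g. with max_views=1 the page view bumps view_count to 1; a >= check here would
+ # block the image/video requests made by that same page, while > still serves them.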
+ if share_row["expires_at"] and datetime.fromisoformat(share_row["expires_at"]) <= datetime.now(timezone.utc):
+ return True
+ if share_row["max_views"] is not None and share_row["view_count"] > share_row["max_views"]:
+ return True
+ return False
+
+
+def get_active_share_for_prompt(prompt_id: str, owner_label: str) -> dict | None:
+ """Return the active (non-expired) share row for *prompt_id*+*owner_label*, or None.
+
+ If a row exists but is already expired, it is auto-deleted and None is returned
+ so the caller can create a new share immediately.
+ """
+ with _connect() as conn:
+ row = conn.execute(
+ """
+ SELECT share_token, is_public, expires_at, max_views, view_count
+ FROM generation_shares
+ WHERE prompt_id = ? AND owner_label = ?
+ """,
+ (prompt_id, owner_label),
+ ).fetchone()
+ if row is None:
+ return None
+ d = dict(row)
+ if _is_share_expired(d):
+ conn.execute(
+ "DELETE FROM generation_shares WHERE share_token = ?",
+ (d["share_token"],),
+ )
+ conn.commit()
+ return None
+ return d
+
+
+def create_share(
+ prompt_id: str,
+ owner_label: str,
+ *,
+ is_public: bool = False,
+ expires_at: str | None = None,
+ max_views: int | None = None,
+) -> dict:
+ """Insert a fresh share row. Returns dict with share_token, is_public, expires_at, max_views."""
token = secrets.token_urlsafe(32)
created_at = datetime.now(timezone.utc).isoformat()
with _connect() as conn:
conn.execute(
"""
- INSERT OR IGNORE INTO generation_shares (share_token, prompt_id, owner_label, created_at)
- VALUES (?, ?, ?, ?)
+ INSERT INTO generation_shares
+ (share_token, prompt_id, owner_label, created_at, is_public, expires_at, max_views, view_count)
+ VALUES (?, ?, ?, ?, ?, ?, ?, 0)
""",
- (token, prompt_id, owner_label, created_at),
+ (token, prompt_id, owner_label, created_at, int(is_public), expires_at, max_views),
)
conn.commit()
row = conn.execute(
- "SELECT share_token FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
- (prompt_id, owner_label),
+ "SELECT share_token, is_public, expires_at, max_views FROM generation_shares WHERE share_token = ?",
+ (token,),
).fetchone()
- return row["share_token"]
+ return dict(row)
-def revoke_share(prompt_id: str, owner_label: str) -> bool:
- """Delete the share token for *prompt_id*. Returns True if a row was deleted."""
+def revoke_share(prompt_id: str, owner_label: str | None = None) -> bool:
+ """Delete the share token for *prompt_id*.
+
+ If *owner_label* is provided, only delete that user's share.
+ If None (admin), delete any share for the prompt_id.
+ Returns True if a row was deleted.
+ """
with _connect() as conn:
- cur = conn.execute(
- "DELETE FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
- (prompt_id, owner_label),
- )
+ if owner_label is not None:
+ cur = conn.execute(
+ "DELETE FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
+ (prompt_id, owner_label),
+ )
+ else:
+ cur = conn.execute(
+ "DELETE FROM generation_shares WHERE prompt_id = ?",
+ (prompt_id,),
+ )
conn.commit()
return cur.rowcount > 0
def get_share_by_token(token: str) -> dict | None:
- """Return generation info for a share token, or None if not found/revoked."""
+ """Return generation info for a share token (incrementing view_count), or None if not found/expired."""
with _connect() as conn:
row = conn.execute(
"""
- SELECT h.prompt_id, h.overrides, h.seed, h.created_at
+ SELECT s.share_token, s.is_public, s.expires_at, s.max_views, s.view_count,
+ h.prompt_id, h.overrides, h.seed, h.created_at
FROM generation_shares s
JOIN generation_history h ON h.prompt_id = s.prompt_id
WHERE s.share_token = ?
@@ -288,6 +385,16 @@ def get_share_by_token(token: str) -> dict | None:
if row is None:
return None
d = dict(row)
+ if _is_share_expired(d):
+ conn.execute("DELETE FROM generation_shares WHERE share_token = ?", (token,))
+ conn.commit()
+ return None
+ # Increment view count
+ conn.execute(
+ "UPDATE generation_shares SET view_count = view_count + 1 WHERE share_token = ?",
+ (token,),
+ )
+ conn.commit()
if d["overrides"]:
try:
d["overrides"] = json.loads(d["overrides"])
@@ -298,6 +405,77 @@ def get_share_by_token(token: str) -> dict | None:
return d
+def get_share_meta(token: str) -> dict | None:
+ """Return share metadata without incrementing view_count. Used by file-streaming endpoints."""
+ with _connect() as conn:
+ row = conn.execute(
+ """
+ SELECT s.share_token, s.is_public, s.expires_at, s.max_views, s.view_count,
+ h.prompt_id, h.overrides, h.seed, h.created_at
+ FROM generation_shares s
+ JOIN generation_history h ON h.prompt_id = s.prompt_id
+ WHERE s.share_token = ?
+ """,
+ (token,),
+ ).fetchone()
+ if row is None:
+ return None
+ d = dict(row)
+ if _is_share_streaming_expired(d):
+ conn.execute("DELETE FROM generation_shares WHERE share_token = ?", (token,))
+ conn.commit()
+ return None
+ if d["overrides"]:
+ try:
+ d["overrides"] = json.loads(d["overrides"])
+ except (json.JSONDecodeError, TypeError):
+ d["overrides"] = {}
+ else:
+ d["overrides"] = {}
+ return d
+
+
+def get_file_ids_for_prompt(prompt_id: str) -> list[int]:
+ """Return generation_files.id values for all files belonging to prompt_id."""
+ with _connect() as conn:
+ rows = conn.execute(
+ """SELECT gf.id FROM generation_files gf
+ JOIN generation_history gh ON gh.id = gf.generation_id
+ WHERE gh.prompt_id = ?""",
+ (prompt_id,),
+ ).fetchall()
+ return [r["id"] for r in rows]
+
+
+def get_generation_ids_for_file_ids(file_ids: list[int]) -> list[int]:
+ """Return distinct generation_id values for the given generation_files row ids."""
+ if not file_ids:
+ return []
+ placeholders = ",".join("?" * len(file_ids))
+ with _connect() as conn:
+ rows = conn.execute(
+ f"SELECT DISTINCT generation_id FROM generation_files WHERE id IN ({placeholders})",
+ tuple(file_ids),
+ ).fetchall()
+ return [r["generation_id"] for r in rows]
+
+
+def get_file_ids_for_generation_ids(gen_ids: list[int]) -> dict[int, list[int]]:
+ """Return {gen_id: [file_id, …]} for the given generation_history row ids."""
+ if not gen_ids:
+ return {}
+ placeholders = ",".join("?" * len(gen_ids))
+ with _connect() as conn:
+ rows = conn.execute(
+ f"SELECT generation_id, id FROM generation_files WHERE generation_id IN ({placeholders})",
+ tuple(gen_ids),
+ ).fetchall()
+ result: dict[int, list[int]] = {gid: [] for gid in gen_ids}
+ for row in rows:
+ result[row["generation_id"]].append(row["id"])
+ return result
+
+
def get_files(prompt_id: str) -> list[dict]:
"""Return all output files for *prompt_id* as ``[{filename, data, mime_type}]``."""
with _connect() as conn:
diff --git a/input_image_db.py b/input_image_db.py
index a0417db..147882f 100644
--- a/input_image_db.py
+++ b/input_image_db.py
@@ -213,6 +213,19 @@ def get_all_images() -> list[dict]:
return [dict(r) for r in rows]
+def get_images_by_ids(ids: list[int]) -> list[dict]:
+ """Return image rows for the given ids (excluding image_data), ordered by id DESC."""
+ if not ids:
+ return []
+ placeholders = ",".join("?" * len(ids))
+ with _connect() as conn:
+ rows = conn.execute(
+ f"SELECT {_SAFE_COLS} FROM input_images WHERE id IN ({placeholders}) ORDER BY id DESC",
+ tuple(ids),
+ ).fetchall()
+ return [dict(r) for r in rows]
+
+
def delete_image(row_id: int, comfy_input_path: str | None = None) -> None:
"""
Remove an image record from the database.
diff --git a/media_uploader.py b/media_uploader.py
index df8b054..9c548f4 100644
--- a/media_uploader.py
+++ b/media_uploader.py
@@ -281,6 +281,6 @@ async def flush_pending(
failed,
)
elif uploaded:
- logger.info("Auto-upload complete: %d file(s) uploaded and deleted.", uploaded)
+ logger.info("Auto-upload complete: %d file(s) uploaded.", uploaded)
return uploaded
diff --git a/sync_faces.py b/sync_faces.py
new file mode 100644
index 0000000..2a40edc
--- /dev/null
+++ b/sync_faces.py
@@ -0,0 +1,131 @@
+#!/usr/bin/env python3
+"""
+sync_faces.py
+=============
+
+One-time backfill script: scan existing input_images and generation_files
+for faces and store detections in faces.db.
+
+Usage:
+ python sync_faces.py [--dry-run] [--input-only] [--output-only] [--cluster] [--cluster-threshold T]
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import logging
+import sqlite3
+
+logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
+logger = logging.getLogger(__name__)
+
+
+async def main(
+ dry_run: bool, input_only: bool, output_only: bool,
+ cluster: bool, cluster_threshold: float,
+) -> None:
+ import face_db
+ from face_service import get_face_service
+
+ face_db.init_db()
+ svc = get_face_service()
+
+ if not svc.available:
+ logger.error(
+ "insightface is not available. "
+ "Install: pip install insightface onnxruntime opencv-python"
+ )
+ return
+
+ import generation_db
+ import input_image_db
+
+ total_faces = 0
+ total_matched = 0
+ total_unidentified = 0
+
+ # Scan input images
+ if not output_only:
+ logger.info("Scanning input images…")
+ conn = sqlite3.connect(str(input_image_db.DB_PATH), check_same_thread=False)
+ conn.row_factory = sqlite3.Row
+ rows = conn.execute(
+ "SELECT id, image_data FROM input_images WHERE image_data IS NOT NULL"
+ ).fetchall()
+ conn.close()
+
+ for row in rows:
+ row_id = row["id"]
+ image_bytes = bytes(row["image_data"])
+ logger.info(" input image id=%d (%d bytes)", row_id, len(image_bytes))
+ if not dry_run:
+ try:
+ results = await svc.scan_input_image(row_id, image_bytes)
+ for r in results:
+ total_faces += 1
+ if r.matched_person_id is not None:
+ total_matched += 1
+ else:
+ total_unidentified += 1
+ except Exception as exc:
+ logger.warning(" Failed for input id=%d: %s", row_id, exc)
+
+ # Scan generated output files
+ if not input_only:
+ logger.info("Scanning generation output files…")
+ conn = sqlite3.connect(str(generation_db._DB_PATH), check_same_thread=False)
+ conn.row_factory = sqlite3.Row
+ rows = conn.execute(
+ "SELECT id, file_data, mime_type FROM generation_files"
+ ).fetchall()
+ conn.close()
+
+ for row in rows:
+ file_id = row["id"]
+ file_data = bytes(row["file_data"])
+ mime_type = row["mime_type"] or ""
+ logger.info(
+ " output file id=%d mime=%s (%d bytes)", file_id, mime_type, len(file_data)
+ )
+ if not dry_run:
+ try:
+ if mime_type.startswith("image/"):
+ results = await svc.scan_output_image(file_id, file_data)
+ total_faces += len(results)
+ total_matched += sum(1 for r in results if r.matched_person_id is not None)
+ elif mime_type.startswith("video/"):
+ results = await svc.scan_video(file_id, file_data)
+ total_faces += len(results)
+ total_matched += sum(1 for r in results if r.matched_person_id is not None)
+ except Exception as exc:
+ logger.warning(" Failed for output id=%d: %s", file_id, exc)
+
+ if dry_run:
+ logger.info("Dry run — no data written.")
+ else:
+ logger.info(
+ "Done. %d faces detected, %d matched to known persons, %d unidentified",
+ total_faces,
+ total_matched,
+ total_unidentified,
+ )
+ if cluster:
+ logger.info("Clustering unidentified faces (threshold=%.2f)…", cluster_threshold)
+ groups = await svc.cluster_unidentified_faces(cluster_threshold)
+ logger.info("Clustering: %d groups created", len(groups))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Backfill face detections for existing media")
+ parser.add_argument("--dry-run", action="store_true")
+ parser.add_argument("--input-only", action="store_true")
+ parser.add_argument("--output-only", action="store_true")
+ parser.add_argument("--cluster", action="store_true", help="Run auto-clustering after scanning")
+ parser.add_argument("--cluster-threshold", type=float, default=0.45, metavar="T",
+ help="Cosine similarity threshold for clustering (default: 0.45)")
+ args = parser.parse_args()
+ asyncio.run(main(
+ args.dry_run, args.input_only, args.output_only,
+ args.cluster, args.cluster_threshold,
+ ))
diff --git a/web/app.py b/web/app.py
index bdae264..5fd8d4b 100644
--- a/web/app.py
+++ b/web/app.py
@@ -178,6 +178,7 @@ def create_app() -> FastAPI:
from web.routers.share_router import router as share_router
from web.routers.workflow_router import router as workflow_router
from web.routers.ws_router import router as ws_router
+ from web.routers.faces_router import router as faces_router
app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
app.include_router(admin_router, prefix="/api/admin", tags=["admin"])
@@ -191,6 +192,7 @@ def create_app() -> FastAPI:
app.include_router(share_router, prefix="/api/share", tags=["share"])
app.include_router(workflow_router, prefix="/api/workflow", tags=["workflow"])
app.include_router(ws_router, tags=["ws"])
+ app.include_router(faces_router, prefix="/api/faces", tags=["faces"])
from web.routers.ws_router import websocket_endpoint as _ws_endpoint
@@ -212,6 +214,11 @@ def create_app() -> FastAPI:
@app.on_event("startup")
async def _startup():
+ try:
+ import face_db
+ face_db.init_db()
+ except Exception as _exc:
+ logger.warning("face_db.init_db() failed (non-fatal): %s", _exc)
asyncio.create_task(_status_ticker())
asyncio.create_task(_server_state_poller())
logger.info("Web background tasks started")
@@ -224,12 +231,12 @@ def create_app() -> FastAPI:
# ---------------------------------------------------------------------------
async def _status_ticker() -> None:
- """Broadcast status_snapshot to all clients every 5 seconds."""
+ """Broadcast status_snapshot to all clients every 2 seconds."""
from web.deps import get_bot, get_comfy, get_config
bus = get_bus()
while True:
- await asyncio.sleep(5)
+ await asyncio.sleep(2)
try:
bot = get_bot()
comfy = get_comfy()
diff --git a/web/auth.py b/web/auth.py
index e61b1a1..533f4e2 100644
--- a/web/auth.py
+++ b/web/auth.py
@@ -37,12 +37,17 @@ _COOKIE_NAME = "ttb_session"
def _get_secret() -> str:
from web.deps import get_config
cfg = get_config()
- if cfg and cfg.web_secret_key:
- return cfg.web_secret_key
- raise RuntimeError(
- "WEB_SECRET_KEY must be set in the environment — "
- "refusing to run with an insecure default."
- )
+ key = cfg.web_secret_key if cfg else ""
+ if not key:
+ raise RuntimeError(
+ "WEB_SECRET_KEY must be set in the environment — "
+ "refusing to run with an insecure default."
+ )
+ if len(key) < 32:
+ raise RuntimeError(
+ "WEB_SECRET_KEY is too short (got %d chars, need ≥ 32)." % len(key)
+ )
+ return key
def create_jwt(label: str, *, admin: bool = False, expire_hours: int = 8) -> str:
@@ -102,6 +107,13 @@ def require_auth(ttb_session: Optional[str] = Cookie(default=None)) -> dict:
return payload
+def optional_auth(ttb_session: Optional[str] = Cookie(default=None)) -> Optional[dict]:
+ """Returns decoded JWT payload or None. Never raises 401."""
+ if not ttb_session:
+ return None
+ return decode_jwt(ttb_session)
+
+
def require_admin(user: dict = Depends(require_auth)) -> dict:
"""
FastAPI dependency that requires an admin JWT.
diff --git a/web/routers/faces_router.py b/web/routers/faces_router.py
new file mode 100644
index 0000000..7183d2e
--- /dev/null
+++ b/web/routers/faces_router.py
@@ -0,0 +1,451 @@
+"""GET/POST /api/faces/* — face recognition endpoints."""
+from __future__ import annotations
+
+import logging
+
+import numpy as np
+from fastapi import APIRouter, Depends, HTTPException, Query
+from fastapi.responses import Response
+from pydantic import BaseModel
+
+from face_service import _SIMILARITY_THRESHOLD
+from web.auth import require_admin
+
+router = APIRouter()
+logger = logging.getLogger(__name__)
+
+
+def _auto_link_for_person(person_id: int, face_db_module, ungrouped: list[dict]) -> int:
+ """Auto-link ungrouped unidentified detections similar to person_id. Returns count linked."""
+ ref_embeddings = face_db_module.get_person_embeddings(person_id)
+ if not ref_embeddings or not ungrouped:
+ return 0
+ ref_M = np.stack(ref_embeddings)
+ ref_M_norm = ref_M / (np.linalg.norm(ref_M, axis=1, keepdims=True) + 1e-8)
+ count = 0
+ for ue in ungrouped:
+ norm_ue = ue["embedding"] / (np.linalg.norm(ue["embedding"]) + 1e-8)
+ if float(np.max(ref_M_norm @ norm_ue)) >= _SIMILARITY_THRESHOLD:
+ face_db_module.link_detection_to_person(ue["id"], person_id)
+ count += 1
+ return count
+
+
+@router.get("/persons")
+async def list_persons(_: dict = Depends(require_admin)):
+ """List all known persons."""
+ import face_db
+ return {"persons": face_db.list_persons()}
+
+
+@router.get("/persons/check")
+async def check_person_name(name: str, _: dict = Depends(require_admin)):
+ """Check whether a person name is already taken (case-insensitive)."""
+ import face_db
+ return {"exists": face_db.person_name_exists(name)}
+
+
+@router.get("/persons/search")
+async def search_persons(
+ q: str = "",
+ limit: int = Query(10, ge=1, le=50),
+ _: dict = Depends(require_admin),
+):
+ """Search persons by name/alias substring."""
+ import face_db
+ all_persons = face_db.list_persons()
+ q_lower = q.strip().lower()
+ if q_lower:
+ filtered = [
+ p for p in all_persons
+ if q_lower in p["name"].lower()
+ or any(q_lower in a["alias"].lower() for a in p["aliases"])
+ ]
+ else:
+ filtered = all_persons
+ return {"persons": filtered[:limit]}
+
+
+@router.get("/persons/{person_id}")
+async def get_person(person_id: int, _: dict = Depends(require_admin)):
+ """Get a single person with aliases."""
+ import face_db
+ person = face_db.get_person(person_id)
+ if person is None:
+ raise HTTPException(404, f"Person {person_id} not found")
+ return person
+
+
+@router.get("/crop/{detection_id}")
+async def get_face_crop(detection_id: int, _: dict = Depends(require_admin)):
+ """Return a JPEG face crop for the given detection id."""
+ import asyncio
+ from face_service import get_face_service
+ svc = get_face_service()
+ if not svc.available:
+ raise HTTPException(503, "Face service not available")
+ loop = asyncio.get_running_loop()
+ crop = await loop.run_in_executor(svc._executor, svc.get_face_crop, detection_id)
+ if crop is None:
+ raise HTTPException(404, "Face crop not found")
+ return Response(content=crop, media_type="image/jpeg")
+
+
+class _IdentifyItem(BaseModel):
+ detection_id: int
+ name: str
+ use_existing: bool = False
+
+
+class _IdentifyRequest(BaseModel):
+ identifications: list[_IdentifyItem]
+
+
+@router.get("/detections/unidentified")
+async def list_unidentified_detections(
+ limit: int = Query(50, ge=1, le=200),
+ offset: int = Query(0, ge=0),
+ _: dict = Depends(require_admin),
+):
+ """List unidentified face detections from input images (paginated)."""
+ import face_db
+ detections, total = face_db.get_unidentified_input_detections(limit=limit, offset=offset)
+ return {"detections": detections, "total": total}
+
+
+class _AliasRequest(BaseModel):
+ alias: str
+
+
+@router.post("/persons/{person_id}/aliases")
+async def add_person_alias(
+ person_id: int,
+ body: _AliasRequest,
+ _: dict = Depends(require_admin),
+):
+ """Add an alias to a person."""
+ import face_db
+ alias = body.alias.strip()
+ if not alias:
+ raise HTTPException(400, "Alias cannot be empty")
+ if len(alias) > 100:
+ raise HTTPException(400, "Alias too long (max 100 chars)")
+ try:
+ alias_id, _ = face_db.add_alias(person_id, alias)
+ except ValueError as e:
+ raise HTTPException(409, str(e)) from e
+ return {"id": alias_id, "alias": alias}
+
+
+@router.delete("/persons/{person_id}/aliases/{alias_id}")
+async def remove_person_alias(
+ person_id: int,
+ alias_id: int,
+ _: dict = Depends(require_admin),
+):
+ """Remove an alias from a person."""
+ import face_db
+ face_db.remove_alias(alias_id)
+ return {"ok": True}
+
+
+class _RenameRequest(BaseModel):
+ name: str
+
+
+@router.patch("/persons/{person_id}")
+async def rename_person(
+ person_id: int,
+ body: _RenameRequest,
+ _: dict = Depends(require_admin),
+):
+ """Rename a person."""
+ import face_db
+ name = body.name.strip()
+ if not name:
+ raise HTTPException(400, "Name cannot be empty")
+ if len(name) > 100:
+ raise HTTPException(400, "Name too long (max 100 chars)")
+ if face_db.get_person(person_id) is None:
+ raise HTTPException(404, f"Person {person_id} not found")
+ try:
+ face_db.rename_person(person_id, name)
+ except ValueError as e:
+ raise HTTPException(409, str(e)) from e
+ return {"ok": True, "person_id": person_id, "name": name}
+
+
+@router.delete("/persons/{person_id}")
+async def delete_person(person_id: int, _: dict = Depends(require_admin)):
+ """Delete a person and unidentify all their detections."""
+ import face_db
+ if face_db.get_person(person_id) is None:
+ raise HTTPException(404, f"Person {person_id} not found")
+ face_db.delete_person(person_id)
+ return {"ok": True}
+
+
+@router.get("/persons/{person_id}/detections")
+async def get_person_detections(
+ person_id: int,
+ limit: int = Query(50, ge=1, le=200),
+ offset: int = Query(0, ge=0),
+ _: dict = Depends(require_admin),
+):
+ """List detections for a person (paginated)."""
+ import face_db
+ detections, total = face_db.get_detections_for_person(person_id, limit=limit, offset=offset)
+ return {"detections": detections, "total": total}
+
+
+class _MergePersonRequest(BaseModel):
+ other_person_id: int
+
+
+@router.post("/persons/{person_id}/merge")
+async def merge_persons(
+ person_id: int,
+ body: _MergePersonRequest,
+ _: dict = Depends(require_admin),
+):
+ """Merge another person into this one (survivor keeps their id)."""
+ import face_db
+ if person_id == body.other_person_id:
+ raise HTTPException(400, "Cannot merge person into themselves")
+ if face_db.get_person(person_id) is None:
+ raise HTTPException(404, f"Person {person_id} not found")
+ if face_db.get_person(body.other_person_id) is None:
+ raise HTTPException(404, f"Person {body.other_person_id} not found")
+ face_db.merge_persons(person_id, body.other_person_id)
+ return {"ok": True, "survivor_id": person_id, "absorbed_id": body.other_person_id}
+
+
+class _ReassignRequest(BaseModel):
+ person_name: str | None = None
+ use_existing: bool = False
+
+
+@router.post("/detections/{detection_id}/reassign")
+async def reassign_detection(
+ detection_id: int,
+ body: _ReassignRequest,
+ _: dict = Depends(require_admin),
+):
+ """
+ Reassign or unidentify a detection.
+ - person_name=null → unidentify (set person_id=NULL)
+ - person_name=str → link to that person (create if needed, or use existing)
+ """
+ import face_db
+ det = face_db.get_detection(detection_id)
+ if det is None:
+ raise HTTPException(404, f"Detection {detection_id} not found")
+ if body.person_name is None:
+ face_db.unidentify_detection(detection_id)
+ return {"detection_id": detection_id, "person_id": None, "unidentified": True}
+ name = body.person_name.strip()
+ if not name:
+ raise HTTPException(400, "Name cannot be empty")
+ if len(name) > 100:
+ raise HTTPException(400, "Name too long (max 100 chars)")
+ exists = face_db.person_name_exists(name)
+ if exists and not body.use_existing:
+ raise HTTPException(
+ 409, f"A person named '{name}' already exists. Set use_existing=true to link."
+ )
+ person_id, is_new = face_db.get_or_create_person(name)
+ face_db.link_detection_to_person(detection_id, person_id)
+ return {"detection_id": detection_id, "person_id": person_id, "is_new": is_new}
+
+
+class _ClusterRequest(BaseModel):
+ threshold: float = 0.45
+
+
+class _MergeGroupsRequest(BaseModel):
+ keep_group_id: int
+ discard_group_id: int
+
+
+class _IdentifyGroupRequest(BaseModel):
+ name: str
+ use_existing: bool = False
+
+
+@router.get("/groups")
+async def list_face_groups(_: dict = Depends(require_admin)):
+ """List face groups with ≥ 2 unidentified detections."""
+ import face_db
+ groups = face_db.get_groups_with_detections()
+ return {"groups": groups, "total": len(groups)}
+
+
+@router.get("/groups/{group_id}/detections")
+async def get_face_group_detections(group_id: int, _: dict = Depends(require_admin)):
+ """Return the unidentified detections belonging to a group (fetched on expand)."""
+ import face_db
+ detections = face_db.get_group_detections(group_id)
+ return {"detections": detections}
+
+
+@router.post("/groups/compute")
+async def compute_face_groups(body: _ClusterRequest, _: dict = Depends(require_admin)):
+ """Run full re-cluster of all unidentified faces."""
+ from face_service import get_face_service
+ if not 0.3 <= body.threshold <= 0.7:
+ raise HTTPException(422, "threshold must be between 0.3 and 0.7")
+ svc = get_face_service()
+ if not svc.available:
+ raise HTTPException(503, "Face service not available")
+ groups = await svc.cluster_unidentified_faces(body.threshold)
+ total_detections = sum(len(g) for g in groups)
+ return {
+ "groups_created": len(groups),
+ "total_detections_clustered": total_detections,
+ "threshold": body.threshold,
+ }
+
+
+@router.post("/groups/merge")
+async def merge_face_groups(body: _MergeGroupsRequest, _: dict = Depends(require_admin)):
+ """Merge two groups into one."""
+ import face_db
+ if body.keep_group_id == body.discard_group_id:
+ raise HTTPException(400, "keep_group_id and discard_group_id must differ")
+ groups_index = {g["id"] for g in face_db.get_groups_with_detections()}
+ if body.keep_group_id not in groups_index:
+ raise HTTPException(404, f"Group {body.keep_group_id} not found")
+ if body.discard_group_id not in groups_index:
+ raise HTTPException(404, f"Group {body.discard_group_id} not found")
+ face_db.merge_groups(body.keep_group_id, body.discard_group_id)
+ return {"ok": True, "surviving_group_id": body.keep_group_id}
+
+
+@router.post("/groups/{group_id}/identify")
+async def identify_face_group(
+ group_id: int,
+ body: _IdentifyGroupRequest,
+ _: dict = Depends(require_admin),
+):
+ """
+ Identify all detections in a group as one person.
+ After identification, auto-links any similar ungrouped unidentified detections.
+ Works for both input and output source types.
+ """
+ import face_db
+ name = body.name.strip()
+ if not name:
+ raise HTTPException(400, "Name cannot be empty")
+ if len(name) > 100:
+ raise HTTPException(400, "Name too long (max 100 chars)")
+
+ detections = face_db.get_group_detections(group_id)
+ if not detections:
+ raise HTTPException(404, f"Group {group_id} not found or has no unidentified detections")
+
+ exists = face_db.person_name_exists(name)
+ if exists and not body.use_existing:
+ raise HTTPException(
+ 409, f"A person named '{name}' already exists. Set use_existing=true to link."
+ )
+
+ person_id, is_new = face_db.get_or_create_person(name)
+ identified_count = 0
+ for det in detections:
+ face_db.link_detection_to_person(det["id"], person_id)
+ identified_count += 1
+
+ face_db.delete_group(group_id)
+
+ # Post-identify: auto-link ungrouped unidentified detections similar to this person
+ ungrouped = face_db.get_ungrouped_unidentified_embeddings()
+ auto_linked_count = _auto_link_for_person(person_id, face_db, ungrouped)
+
+ return {
+ "person_id": person_id,
+ "person_name": name,
+ "is_new": is_new,
+ "identified_count": identified_count,
+ "auto_linked_count": auto_linked_count,
+ }
+
+
+@router.delete("/groups/{group_id}/detections/{detection_id}")
+async def remove_group_detection(
+ group_id: int,
+ detection_id: int,
+ _: dict = Depends(require_admin),
+):
+ """Remove a single detection from its group. Cleans up singleton groups."""
+ import face_db
+ det = face_db.get_detection(detection_id)
+ if det is None or det.get("group_id") != group_id:
+ raise HTTPException(404, "Detection not found in this group")
+ face_db.remove_detection_from_group(detection_id)
+ face_db.delete_singleton_groups()
+ return {"ok": True}
+
+
+@router.post("/rescan/outputs")
+async def rescan_output_embeddings(_: dict = Depends(require_admin)):
+ """Re-detect faces in stored output images to rebuild NULL embeddings."""
+ import face_db
+ from face_service import get_face_service
+ svc = get_face_service()
+ if not svc.available:
+ raise HTTPException(503, "Face service not available")
+ source_ids = face_db.get_null_embedding_output_source_ids()
+ total_updated = 0
+ for source_id in source_ids:
+ updated = await svc.rescan_output_embedding(source_id)
+ total_updated += updated
+ return {"processed": len(source_ids), "updated": total_updated}
+
+
+@router.post("/identify")
+async def identify_faces(body: _IdentifyRequest, _: dict = Depends(require_admin)):
+ """
+ Identify one or more face detections by name.
+
+ - If the name is new → creates a new person and links the detection.
+ - If the name exists and ``use_existing=true`` → links to the existing person.
+ - If the name exists and ``use_existing=false`` → HTTP 409.
+ - Only detections with ``source_type='input'`` may be identified via the web UI.
+ """
+ import face_db
+ results = []
+ for item in body.identifications:
+ name = item.name.strip()
+ if not name:
+ raise HTTPException(400, "Name cannot be empty")
+ if len(name) > 100:
+ raise HTTPException(400, "Name too long (max 100 chars)")
+
+ det = face_db.get_detection(item.detection_id)
+ if det is None:
+ raise HTTPException(404, f"Detection {item.detection_id} not found")
+ if det["source_type"] != "input":
+ raise HTTPException(403, "Only input-image detections may be identified via web UI")
+
+ exists = face_db.person_name_exists(name)
+ if exists and not item.use_existing:
+ raise HTTPException(
+ 409,
+ f"A person named '{name}' already exists. Set use_existing=true to link.",
+ )
+
+ person_id, is_new = face_db.get_or_create_person(name)
+ face_db.link_detection_to_person(item.detection_id, person_id)
+
+ results.append({
+ "detection_id": item.detection_id,
+ "person_id": person_id,
+ "person_name": name,
+ "is_new": is_new,
+ })
+
+ # Auto-link similar ungrouped faces for each person identified in this batch
+ ungrouped = face_db.get_ungrouped_unidentified_embeddings()
+ unique_pids = {r["person_id"] for r in results}
+ auto_linked_count = sum(_auto_link_for_person(pid, face_db, ungrouped) for pid in unique_pids)
+ return {"identifications": results, "auto_linked_count": auto_linked_count}
diff --git a/web/routers/generate_router.py b/web/routers/generate_router.py
index b85162a..b1a3a6a 100644
--- a/web/routers/generate_router.py
+++ b/web/routers/generate_router.py
@@ -17,6 +17,32 @@ router = APIRouter()
logger = logging.getLogger(__name__)
+def _materialize_image_slots(
+ overrides: dict, comfy_input_path: str
+) -> tuple[dict, list[str]]:
+ """
+ For each override whose value is an existing ttb_* file, copy it to a
+ unique name so concurrent jobs each have an immutable copy on disk.
+ Returns (updated_overrides, paths_to_delete_after_generation).
+ """
+ import shutil
+ import uuid as _uuid
+ if not comfy_input_path:
+ return overrides, []
+ updated = dict(overrides)
+ cleanup: list[str] = []
+ input_dir = Path(comfy_input_path)
+ for key, val in overrides.items():
+ if isinstance(val, str) and val.startswith("ttb_") and "." in val:
+ src = input_dir / val
+ if src.is_file():
+ unique_name = f"{src.stem}_{_uuid.uuid4().hex[:8]}{src.suffix}"
+ shutil.copy2(src, input_dir / unique_name)
+ updated[key] = unique_name
+ cleanup.append(str(input_dir / unique_name))
+ return updated, cleanup
+
+
class GenerateRequest(BaseModel):
prompt: str
negative_prompt: Optional[str] = None
@@ -105,7 +131,8 @@ async def generate(body: GenerateRequest, user: dict = Depends(require_auth)):
from generation_db import record_generation, record_file
gen_id = record_generation(pid, "web", user_label, overrides_for_gen, seed_used)
for i, img_data in enumerate(images):
- record_file(gen_id, f"image_{i:04d}.png", img_data)
+ file_id = record_file(gen_id, f"image_{i:04d}.png", img_data)
+ comfy._schedule_face_scan("image", file_id, img_data)
if config and videos:
for vid in videos:
vsub = vid.get("video_subfolder", "")
@@ -116,7 +143,9 @@ async def generate(body: GenerateRequest, user: dict = Depends(require_auth)):
else Path(config.comfy_output_path) / vname
)
try:
- record_file(gen_id, vname, vpath.read_bytes())
+ vid_data = vpath.read_bytes()
+ file_id = record_file(gen_id, vname, vid_data)
+ comfy._schedule_face_scan("video", file_id, vid_data)
except OSError:
pass
except Exception as exc:
@@ -163,25 +192,32 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
registry = get_user_registry()
count = max(1, min(body.count, 20)) # cap at 20
- async def _run_one():
- # Use the user's own state and template
- if registry:
- user_sm = registry.get_state_manager(user_label)
- user_template = registry.get_workflow_template(user_label)
- else:
- user_sm = comfy.state_manager
- user_template = comfy.workflow_manager.get_workflow_template()
+ # --- snapshot state at queue time, not at execution time ---
+ if registry:
+ _user_sm = registry.get_state_manager(user_label)
+ _user_template = registry.get_workflow_template(user_label)
+ else:
+ _user_sm = comfy.state_manager
+ _user_template = comfy.workflow_manager.get_workflow_template()
- if not user_template:
+ base_overrides = _user_sm.get_overrides()
+ if body.overrides:
+ base_overrides = {**base_overrides, **body.overrides}
+
+ _config = get_config()
+
+ async def _run_one(overrides: dict, cleanup_paths: list[str]):
+ if not _user_template:
await bus.broadcast_to_user(user_label, "generation_error", {
"prompt_id": None, "error": "No workflow template loaded"
})
+ for p in cleanup_paths:
+ try:
+ Path(p).unlink(missing_ok=True)
+ except Exception:
+ pass
return
- overrides = user_sm.get_overrides()
- if body.overrides:
- overrides = {**overrides, **body.overrides}
-
import uuid
pid = str(uuid.uuid4())
@@ -190,7 +226,7 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
"node": node, "prompt_id": pid_
}))
- workflow, applied = comfy.inspector.inject_overrides(user_template, overrides)
+ workflow, applied = comfy.inspector.inject_overrides(_user_template, overrides)
seed_used = applied.get("seed")
comfy.last_seed = seed_used
@@ -201,17 +237,23 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
await bus.broadcast_to_user(user_label, "generation_error", {
"prompt_id": None, "error": str(exc)
})
+ for p in cleanup_paths:
+ try:
+ Path(p).unlink(missing_ok=True)
+ except Exception:
+ pass
return
comfy.last_prompt_id = pid
comfy.total_generated += 1
- config = get_config()
+ config = _config
try:
from generation_db import record_generation, record_file
gen_id = record_generation(pid, "web", user_label, overrides, seed_used)
for i, img_data in enumerate(images):
- record_file(gen_id, f"image_{i:04d}.png", img_data)
+ file_id = record_file(gen_id, f"image_{i:04d}.png", img_data)
+ comfy._schedule_face_scan("image", file_id, img_data)
if config and videos:
for vid in videos:
vsub = vid.get("video_subfolder", "")
@@ -222,7 +264,9 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
else Path(config.comfy_output_path) / vname
)
try:
- record_file(gen_id, vname, vpath.read_bytes())
+ vid_data = vpath.read_bytes()
+ file_id = record_file(gen_id, vname, vid_data)
+ comfy._schedule_face_scan("video", file_id, vid_data)
except OSError:
pass
except Exception as exc:
@@ -236,6 +280,13 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
config.media_upload_pass,
))
+ # Clean up unique image copies now that ComfyUI has ingested them
+ for p in cleanup_paths:
+ try:
+ Path(p).unlink(missing_ok=True)
+ except Exception:
+ pass
+
await bus.broadcast("queue_update", {"prompt_id": pid, "status": "complete"})
await bus.broadcast_to_user(user_label, "generation_complete", {
"prompt_id": pid,
@@ -246,7 +297,10 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
depth = await comfy.get_queue_depth()
for _ in range(count):
- asyncio.create_task(_run_one())
+ job_overrides, cleanup = _materialize_image_slots(
+ base_overrides, _config.comfy_input_path if _config else ""
+ )
+ asyncio.create_task(_run_one(job_overrides, cleanup))
return {
"queued": True,
diff --git a/web/routers/history_router.py b/web/routers/history_router.py
index 653f6ca..2a0a261 100644
--- a/web/routers/history_router.py
+++ b/web/routers/history_router.py
@@ -3,12 +3,14 @@ POST /api/history/{prompt_id}/share; DELETE /api/history/{prompt_id}/share"""
from __future__ import annotations
import base64
+from datetime import datetime, timedelta, timezone
from typing import Optional
-from fastapi import APIRouter, Depends, HTTPException, Query, Request
+from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request
from fastapi.responses import Response
+from pydantic import BaseModel
-from web.auth import require_auth
+from web.auth import require_admin, require_auth
router = APIRouter()
@@ -31,22 +33,135 @@ def _assert_owns(prompt_id: str, user: dict) -> None:
async def get_history(
user: dict = Depends(require_auth),
q: Optional[str] = Query(None, description="Keyword to search in overrides JSON"),
+ persons: list[str] = Query(default=[], alias="persons", description="Filter by person name/alias substring (repeatable)"),
):
"""Return generation history. Admins see all; regular users see only their own.
- Pass ?q=keyword to filter by prompt text or any override field."""
+ Pass ?q=keyword to filter by prompt text or any override field.
+ Pass ?persons=name (repeatable) to filter by persons detected in output images."""
+ import face_db as face_db_mod
from generation_db import (
get_history as db_get_history,
get_history_for_user,
search_history,
search_history_for_user,
+ get_generation_ids_for_file_ids,
+ get_file_ids_for_generation_ids,
)
+
+ # Person filter: OR union across all named persons, then filter rows
+ gen_id_filter: list[int] | None = None
+ all_file_ids: set[int] = set()
+ active_persons = [p.strip() for p in persons if p.strip()]
+ if active_persons:
+ for p in active_persons:
+ ids = face_db_mod.get_source_ids_for_person_query(p, "output")
+ all_file_ids.update(ids)
+ if not all_file_ids:
+ return {"history": []}
+ gen_ids = get_generation_ids_for_file_ids(list(all_file_ids))
+ if not gen_ids:
+ return {"history": []}
+ gen_id_filter = gen_ids
+
if q and q.strip():
if user.get("admin"):
- return {"history": search_history(q.strip(), limit=50)}
- return {"history": search_history_for_user(user["sub"], q.strip(), limit=50)}
- if user.get("admin"):
- return {"history": db_get_history(limit=50)}
- return {"history": get_history_for_user(user["sub"], limit=50)}
+ rows = search_history(q.strip(), limit=50)
+ else:
+ rows = search_history_for_user(user["sub"], q.strip(), limit=50)
+ else:
+ if user.get("admin"):
+ rows = db_get_history(limit=50)
+ else:
+ rows = get_history_for_user(user["sub"], limit=50)
+
+ if gen_id_filter is not None:
+ gen_id_set = set(gen_id_filter)
+ rows = [r for r in rows if r["id"] in gen_id_set]
+
+ # Annotate each row with detected_persons when person filtering is active
+ if active_persons and rows:
+ row_ids = [r["id"] for r in rows]
+ file_ids_per_gen = get_file_ids_for_generation_ids(row_ids)
+ all_annotate_ids = [fid for fids in file_ids_per_gen.values() for fid in fids]
+ person_map = face_db_mod.get_persons_for_source_id_map(all_annotate_ids, "output") if all_annotate_ids else {}
+ for row in rows:
+ row["detected_persons"] = list({
+ name
+ for fid in file_ids_per_gen.get(row["id"], [])
+ for name in person_map.get(fid, [])
+ })
+
+ return {"history": rows}
+
+
+@router.get("/{prompt_id}/persons")
+async def get_generation_persons(prompt_id: str, user: dict = Depends(require_auth)):
+ """Return persons detected in the output images of a generation."""
+ _assert_owns(prompt_id, user)
+ import generation_db as gen_db
+ import face_db as face_db_mod
+ file_ids = gen_db.get_file_ids_for_prompt(prompt_id)
+ persons = face_db_mod.get_persons_for_source_ids(file_ids, "output")
+ return {"persons": persons}
+
+
+class _AddPersonRequest(BaseModel):
+ name: str
+
+
+@router.post("/{prompt_id}/persons")
+async def add_generation_person(
+ prompt_id: str,
+ body: _AddPersonRequest,
+ _: dict = Depends(require_admin),
+):
+ """Manually tag a person to a generation's output files."""
+ import generation_db as gen_db
+ import face_db as face_db_mod
+
+ file_ids = gen_db.get_file_ids_for_prompt(prompt_id)
+ if not file_ids:
+ raise HTTPException(404, f"Generation {prompt_id!r} not found")
+
+ name = body.name.strip()
+ if not name:
+ raise HTTPException(400, "Name cannot be empty")
+ if len(name) > 100:
+ raise HTTPException(400, "Name too long (max 100 chars)")
+
+ person_id, _ = face_db_mod.get_or_create_person(name)
+
+ # Only add if not already tagged on any file of this generation
+ existing = face_db_mod.get_persons_for_source_ids(file_ids, "output")
+ if not any(p["id"] == person_id for p in existing):
+ face_db_mod.insert_detection(
+ source_type="output",
+ source_id=file_ids[0],
+ embedding=None,
+ bbox={},
+ face_index=0,
+ person_id=person_id,
+ )
+
+ return {"person_id": person_id, "name": name}
+
+
+@router.delete("/{prompt_id}/persons/{person_id}")
+async def remove_generation_person(
+ prompt_id: str,
+ person_id: int,
+ _: dict = Depends(require_admin),
+):
+ """Remove all face detections linking a person to this generation's output files."""
+ import generation_db as gen_db
+ import face_db as face_db_mod
+
+ file_ids = gen_db.get_file_ids_for_prompt(prompt_id)
+ if not file_ids:
+ raise HTTPException(404, f"Generation {prompt_id!r} not found")
+
+ face_db_mod.remove_person_from_source_ids(file_ids, person_id, "output")
+ return {"ok": True}
@router.get("/{prompt_id}/images")
@@ -101,6 +216,11 @@ async def get_history_file(
start = int(start_str) if start_str else 0
end = int(end_str) if end_str else total - 1
end = min(end, total - 1)
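+ # Out-of-range or inverted byte ranges get 416 with the total size, per RFC 7233.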
+ if start < 0 or start > end:
+ return Response(
+ status_code=416,
+ headers={"Content-Range": f"bytes */{total}"},
+ )
chunk = data[start : end + 1]
return Response(
content=chunk,
@@ -110,6 +230,7 @@ async def get_history_file(
"Content-Range": f"bytes {start}-{end}/{total}",
"Accept-Ranges": "bytes",
"Content-Length": str(len(chunk)),
+ "Cache-Control": "private",
},
)
@@ -119,25 +240,62 @@ async def get_history_file(
headers={
"Accept-Ranges": "bytes",
"Content-Length": str(total),
+ "Cache-Control": "private",
},
)
+class _ShareRequest(BaseModel):
+ is_public: bool = False
+ expires_in_hours: Optional[float] = None
+ max_views: Optional[int] = None
+
+
@router.post("/{prompt_id}/share")
-async def create_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
+async def create_generation_share(
+ prompt_id: str,
+ body: _ShareRequest = Body(default_factory=_ShareRequest),
+ user: dict = Depends(require_auth),
+):
"""Create a share token for a generation. Only the owner may share."""
- # Use the same 404-for-everything helper to avoid leaking prompt_id existence
_assert_owns(prompt_id, user)
- from generation_db import create_share as db_create_share
- token = db_create_share(prompt_id, user["sub"])
- return {"share_token": token}
+ from generation_db import (
+ create_share as db_create_share,
+ get_active_share_for_prompt as db_get_active_share,
+ )
+ existing = db_get_active_share(prompt_id, user["sub"])
+ if existing:
+ raise HTTPException(409, "A share already exists — revoke it first")
+ if body.is_public and body.expires_in_hours is None and body.max_views is None:
+ raise HTTPException(400, "Public shares must have at least one expiry condition (time or view limit)")
+ if body.max_views is not None and body.max_views < 1:
+ raise HTTPException(400, "max_views must be >= 1")
+ if body.expires_in_hours is not None and body.expires_in_hours <= 0:
+ raise HTTPException(400, "expires_in_hours must be > 0")
+ expires_at = None
+ if body.expires_in_hours is not None:
+ expires_at = (datetime.now(timezone.utc) + timedelta(hours=body.expires_in_hours)).isoformat()
+ share = db_create_share(
+ prompt_id, user["sub"],
+ is_public=body.is_public,
+ expires_at=expires_at,
+ max_views=body.max_views,
+ )
+ return {
+ "share_token": share["share_token"],
+ "is_public": bool(share["is_public"]),
+ "expires_at": share["expires_at"],
+ "max_views": share["max_views"],
+ }
@router.delete("/{prompt_id}/share")
async def revoke_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
- """Revoke a share token for a generation. Only the owner may revoke."""
+ """Revoke a share token for a generation. Admins can revoke any share."""
from generation_db import revoke_share as db_revoke_share
- deleted = db_revoke_share(prompt_id, user["sub"])
+ # Admins pass None for owner_label to delete by prompt_id alone
+ owner_label = None if user.get("admin") else user["sub"]
+ deleted = db_revoke_share(prompt_id, owner_label)
if not deleted:
raise HTTPException(404, "No active share found for this generation")
return {"ok": True}
diff --git a/web/routers/inputs_router.py b/web/routers/inputs_router.py
index a5ba476..917290d 100644
--- a/web/routers/inputs_router.py
+++ b/web/routers/inputs_router.py
@@ -3,10 +3,11 @@ from __future__ import annotations
import logging
import mimetypes
+import re
from pathlib import Path
from typing import Optional
-from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, UploadFile
+from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import Response
from web.auth import require_auth
@@ -15,10 +16,38 @@ from web.deps import get_config, get_user_registry
router = APIRouter()
logger = logging.getLogger(__name__)
+_ALLOWED_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
+_MAX_UPLOAD_BYTES = 50 * 1024 * 1024 # 50 MB
+_SAFE_SLOT_RE = re.compile(r'^[a-zA-Z0-9_\-]+$')
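+# e.g. "image_slot_1" passes; "../etc/passwd" or "a b" is rejected.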
+
+
+def _validate_slot_key(slot_key: str) -> None:
+ if not _SAFE_SLOT_RE.match(slot_key):
+ raise HTTPException(400, "slot_key may only contain letters, digits, hyphens, underscores")
+
@router.get("")
-async def list_inputs(_: dict = Depends(require_auth)):
- """List all input images (Discord + web uploads)."""
+async def list_inputs(
+ _: dict = Depends(require_auth),
+ persons: list[str] = Query(default=[], alias="persons", description="Filter by person name/alias substring (repeatable)"),
+):
+ """List all input images (Discord + web uploads). Optionally filter by persons."""
+ active_persons = [p.strip() for p in persons if p.strip()]
+ if active_persons:
+ import face_db as face_db_mod
+ from input_image_db import get_images_by_ids
+ all_ids: set[int] = set()
+ for p in active_persons:
+ ids = face_db_mod.get_source_ids_for_person_query(p, "input")
+ all_ids.update(ids)
+ images = list(get_images_by_ids(list(all_ids))) if all_ids else []
+ if images:
+ person_map = face_db_mod.get_persons_for_source_id_map(
+ [img["id"] for img in images], "input"
+ )
+ for img in images:
+ img["detected_persons"] = person_map.get(img["id"], [])
+ return images
from input_image_db import get_all_images
rows = get_all_images()
return [dict(r) for r in rows]
@@ -44,8 +73,16 @@ async def upload_input(
if config is None:
raise HTTPException(503, "Config not available")
+ if slot_key:
+ _validate_slot_key(slot_key)
+
data = await file.read()
filename = file.filename or "upload.png"
+ ext = Path(filename).suffix.lower()
+ if ext not in _ALLOWED_EXTS:
+ raise HTTPException(415, f"Unsupported file type '{ext}'. Allowed: {sorted(_ALLOWED_EXTS)}")
+ if len(data) > _MAX_UPLOAD_BYTES:
+ raise HTTPException(413, "File too large (max 50 MB)")
from input_image_db import upsert_image, activate_image_for_slot
row_id = upsert_image(
@@ -72,7 +109,30 @@ async def upload_input(
if comfy:
comfy.state_manager.set_override(slot_key, activated_filename)
- return {"id": row_id, "filename": filename, "slot_key": slot_key, "activated_filename": activated_filename}
+ # Face scan: awaited inline in this request (~1-2 s); unidentified faces are returned to the UI
+ pending_faces: list[dict] = []
+ try:
+ from face_service import get_face_service
+ import face_db as _face_db
+ _face_db.init_db()
+ svc = get_face_service()
+ if svc.available:
+ results = await svc.scan_input_image(row_id, data)
+ pending_faces = [
+ {"detection_id": r.detection_id, "face_index": r.face_index, "bbox": r.bbox}
+ for r in results
+ if r.matched_person_id is None
+ ]
+ except Exception as exc:
+ logger.warning("Face scan failed for upload row_id=%d: %s", row_id, exc)
+
+ return {
+ "id": row_id,
+ "filename": filename,
+ "slot_key": slot_key,
+ "activated_filename": activated_filename,
+ "pending_faces": pending_faces,
+ }
@router.post("/{row_id}/activate")
@@ -92,6 +152,7 @@ async def activate_input(
raise HTTPException(404, "Image not found")
user_label: str = user["sub"]
+ _validate_slot_key(slot_key)
namespaced_key = f"{user_label}_{slot_key}"
try:
diff --git a/web/routers/share_router.py b/web/routers/share_router.py
index 9b06caa..0f3f645 100644
--- a/web/routers/share_router.py
+++ b/web/routers/share_router.py
@@ -2,28 +2,40 @@
from __future__ import annotations
import base64
+from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
-from web.auth import require_auth
+from web.auth import optional_auth
router = APIRouter()
+def _auth_gate(gen: dict, user: Optional[dict]) -> None:
+ """Raise 401 if the share is private and the user is not authenticated."""
+ if not gen.get("is_public") and user is None:
+ raise HTTPException(401, "Authentication required")
+
+
@router.get("/{token}")
-async def get_share(token: str, _: dict = Depends(require_auth)):
- """Fetch share metadata and images. Any authenticated user may view a valid share link."""
+async def get_share(token: str, user: Optional[dict] = Depends(optional_auth)):
+ """Fetch share metadata and images. Public shares require no login; private shares require auth."""
from generation_db import get_share_by_token, get_files
gen = get_share_by_token(token)
if gen is None:
- raise HTTPException(404, "Share not found or revoked")
+ raise HTTPException(404, "Share not found, expired, or revoked")
+ _auth_gate(gen, user)
files = get_files(gen["prompt_id"])
return {
"prompt_id": gen["prompt_id"],
"created_at": gen["created_at"],
"overrides": gen["overrides"],
"seed": gen["seed"],
+ "is_public": bool(gen["is_public"]),
+ "expires_at": gen["expires_at"],
+ "max_views": gen["max_views"],
+ "view_count": gen["view_count"],
"images": [
{
"filename": f["filename"],
@@ -40,13 +52,14 @@ async def get_share_file(
token: str,
filename: str,
request: Request,
- _: dict = Depends(require_auth),
+ user: Optional[dict] = Depends(optional_auth),
):
"""Stream a single output file via share token, with HTTP range support for video seeking."""
- from generation_db import get_share_by_token, get_files
- gen = get_share_by_token(token)
+ from generation_db import get_share_meta, get_files
+ gen = get_share_meta(token)
if gen is None:
- raise HTTPException(404, "Share not found or revoked")
+ raise HTTPException(404, "Share not found, expired, or revoked")
+ _auth_gate(gen, user)
files = get_files(gen["prompt_id"])
matched = next((f for f in files if f["filename"] == filename), None)
if matched is None:
@@ -63,6 +76,11 @@ async def get_share_file(
start = int(start_str) if start_str else 0
end = int(end_str) if end_str else total - 1
end = min(end, total - 1)
+ if start < 0 or start > end:
+ return Response(
+ status_code=416,
+ headers={"Content-Range": f"bytes */{total}"},
+ )
chunk = data[start : end + 1]
return Response(
content=chunk,
@@ -72,6 +90,7 @@ async def get_share_file(
"Content-Range": f"bytes {start}-{end}/{total}",
"Accept-Ranges": "bytes",
"Content-Length": str(len(chunk)),
+ "Cache-Control": "public, max-age=3600" if gen.get("is_public") else "private",
},
)
@@ -81,5 +100,6 @@ async def get_share_file(
headers={
"Accept-Ranges": "bytes",
"Content-Length": str(total),
+ "Cache-Control": "public, max-age=3600" if gen.get("is_public") else "private",
},
)