From 6004b000a769b0c5d8478a93a3c1a166e917e359 Mon Sep 17 00:00:00 2001 From: "Khoa (Revenovich) Tran Gia" Date: Sat, 7 Mar 2026 21:49:16 +0700 Subject: [PATCH] manual submit --- bot.py | 14 + comfy_client.py | 22 +- commands/input_images.py | 127 ++ face_db.py | 715 ++++++++++ face_service.py | 565 ++++++++ frontend/package-lock.json | 59 + frontend/package.json | 6 +- frontend/src/App.tsx | 9 + frontend/src/api/client.ts | 211 ++- .../src/components/DynamicWorkflowForm.tsx | 61 +- frontend/src/components/FaceIdentifyModal.tsx | 183 +++ frontend/src/components/GlassCard.tsx | 21 + frontend/src/components/Layout.tsx | 224 ++- frontend/src/hooks/useStatus.ts | 16 +- frontend/src/hooks/useWebSocket.ts | 8 +- frontend/src/index.css | 45 +- frontend/src/pages/AdminPage.tsx | 44 +- frontend/src/pages/FacesPage.tsx | 1240 +++++++++++++++++ frontend/src/pages/GeneratePage.tsx | 145 +- frontend/src/pages/HistoryPage.tsx | 675 +++++++-- frontend/src/pages/InputImagesPage.tsx | 264 +++- frontend/src/pages/LoginPage.tsx | 35 +- frontend/src/pages/PresetsPage.tsx | 55 +- frontend/src/pages/ServerPage.tsx | 43 +- frontend/src/pages/SharePage.tsx | 98 +- frontend/src/pages/StatusPage.tsx | 159 ++- frontend/src/pages/WorkflowPage.tsx | 102 +- frontend/tailwind.config.js | 17 +- generation_db.py | 226 ++- input_image_db.py | 13 + media_uploader.py | 2 +- sync_faces.py | 131 ++ web/app.py | 11 +- web/auth.py | 24 +- web/routers/faces_router.py | 451 ++++++ web/routers/generate_router.py | 94 +- web/routers/history_router.py | 188 ++- web/routers/inputs_router.py | 69 +- web/routers/share_router.py | 36 +- 39 files changed, 5794 insertions(+), 614 deletions(-) create mode 100644 face_db.py create mode 100644 face_service.py create mode 100644 frontend/src/components/FaceIdentifyModal.tsx create mode 100644 frontend/src/components/GlassCard.tsx create mode 100644 frontend/src/pages/FacesPage.tsx create mode 100644 sync_faces.py create mode 100644 web/routers/faces_router.py diff --git 
a/bot.py b/bot.py index 370b4e1..0c06ea4 100644 --- a/bot.py +++ b/bot.py @@ -143,6 +143,14 @@ def _try_autoload_last_workflow(client: ComfyClient) -> None: if not last_wf: return wf_path = _PROJECT_ROOT / "workflows" / last_wf + # Guard against path traversal in the persisted state file + try: + safe_root = (_PROJECT_ROOT / "workflows").resolve() + if not wf_path.resolve().is_relative_to(safe_root): + logger.warning("Blocked path traversal attempt in last_workflow_file: %r", last_wf) + return + except Exception: + return if not wf_path.exists(): logger.warning("Last workflow file not found: %s", wf_path) return @@ -190,6 +198,12 @@ async def main() -> None: init_db() generation_db.init_db(_PROJECT_ROOT / "generation_history.db") + try: + import face_db as _face_db + _face_db.init_db() + logger.info("Face DB initialized") + except Exception as _exc: + logger.warning("Face DB init failed (non-fatal): %s", _exc) register_all_commands(bot, config) logger.info("All commands registered") diff --git a/comfy_client.py b/comfy_client.py index e0cf712..e101fe6 100644 --- a/comfy_client.py +++ b/comfy_client.py @@ -265,7 +265,8 @@ class ComfyClient: prompt_id, source, user_label, overrides, seed ) for i, img_data in enumerate(images): - generation_db.record_file(gen_id, f"image_{i:04d}.png", img_data) + file_id = generation_db.record_file(gen_id, f"image_{i:04d}.png", img_data) + self._schedule_face_scan("image", file_id, img_data) if videos and self.output_path: for vid in videos: vname = vid.get("video_name", "") @@ -276,7 +277,9 @@ class ComfyClient: else _Path(self.output_path) / vname ) try: - generation_db.record_file(gen_id, vname, vpath.read_bytes()) + vid_data = vpath.read_bytes() + file_id = generation_db.record_file(gen_id, vname, vid_data) + self._schedule_face_scan("video", file_id, vid_data) except OSError as exc: self.logger.warning( "Could not read video for DB storage: %s: %s", vpath, exc @@ -284,6 +287,21 @@ class ComfyClient: except Exception as exc: 
self.logger.warning("Failed to record generation to DB: %s", exc) + def _schedule_face_scan(self, media_type: str, file_id: int, data: bytes) -> None: + """Fire-and-forget background face scan for a generated output file.""" + try: + from face_service import get_face_service + svc = get_face_service() + if not svc.available: + return + loop = asyncio.get_running_loop() + if media_type == "image": + loop.create_task(svc.scan_output_image(file_id, data)) + elif media_type == "video": + loop.create_task(svc.scan_video(file_id, data)) + except Exception as exc: + self.logger.warning("Could not schedule face scan: %s", exc) + # ------------------------------------------------------------------ # Public generation API # ------------------------------------------------------------------ diff --git a/commands/input_images.py b/commands/input_images.py index 2c41882..cc09e92 100644 --- a/commands/input_images.py +++ b/commands/input_images.py @@ -14,6 +14,7 @@ stored in the SQLite database. from __future__ import annotations +import asyncio import io import logging from pathlib import Path @@ -30,6 +31,127 @@ from input_image_db import ( logger = logging.getLogger(__name__) + +async def _identify_faces_discord( + bot, + message: discord.Message, + row_id: int, + image_bytes: bytes, +) -> None: + """ + After an input image is registered, scan it for faces and prompt the + uploader to identify any that weren't auto-matched. + + Only the original uploader's replies are accepted (author check on wait_for). + The loop runs at most 3 rounds to resolve deduplication conflicts. 
+ """ + try: + from face_service import get_face_service + import face_db + + svc = get_face_service() + if not svc.available: + return + + face_db.init_db() + results = await svc.scan_input_image(row_id, image_bytes) + unknown = [r for r in results if r.matched_person_id is None] + if not unknown: + return + + # Build initial prompt with face crops as attachments + files = [] + for r in unknown: + crop = svc.get_face_crop(r.detection_id) + if crop: + files.append(discord.File(io.BytesIO(crop), filename=f"face_{r.face_index}.jpg")) + + n = len(unknown) + prompt_text = ( + f"\U0001f50d Found {n} new face(s) in your image. " + f"Reply with names in order (comma-separated): `Name1, Name2, ...`\n" + f"_(or ignore to skip identification)_" + ) + bot_msg = await message.channel.send(prompt_text, files=files) + + def _check(m: discord.Message) -> bool: + return ( + m.reference is not None + and m.reference.message_id == bot_msg.id + and m.author.id == message.author.id + ) + + pending = list(unknown) # detections still needing names + + for _round in range(3): + try: + reply = await bot.wait_for("message", check=_check, timeout=120) + except asyncio.TimeoutError: + return + + raw_names = [n.strip() for n in reply.content.split(",")] + if len(raw_names) < len(pending): + raw_names += [""] * (len(pending) - len(raw_names)) + + # Check for conflicts (name exists but user said a new name) + conflicts: list[tuple[int, str]] = [] # (index in pending, name) + for idx, (det, name) in enumerate(zip(pending, raw_names)): + if name and face_db.person_name_exists(name): + conflicts.append((idx, name)) + + if conflicts: + conflict_lines = "\n".join( + f"Face {pending[idx].face_index + 1} → `{name}`" + for idx, name in conflicts + ) + warn_msg = await reply.reply( + f"\u26a0\ufe0f These names already exist:\n{conflict_lines}\n\n" + f"Reply `same` for any that should link to the **existing** person, " + f"or provide a different name — one value per conflicting face " + 
f"(comma-separated, in the same order as listed above)." + ) + bot_msg = warn_msg + + # Update check to listen for reply to the new warning message + def _check_conflict(m: discord.Message) -> bool: + return ( + m.reference is not None + and m.reference.message_id == warn_msg.id + and m.author.id == message.author.id + ) + + try: + conflict_reply = await bot.wait_for( + "message", check=_check_conflict, timeout=120 + ) + except asyncio.TimeoutError: + return + + resolved = [v.strip() for v in conflict_reply.content.split(",")] + # Apply resolved names back to the original name list + for list_pos, (pending_idx, old_name) in enumerate(conflicts): + if list_pos < len(resolved): + val = resolved[list_pos] + raw_names[pending_idx] = old_name if val.lower() == "same" else val + + # Apply names — skip blanks + confirmed: list[str] = [] + for det, name in zip(pending, raw_names): + if not name: + continue + use_existing = face_db.person_name_exists(name) + person_id, _ = face_db.get_or_create_person(name) + face_db.link_detection_to_person(det.detection_id, person_id) + status = "linked to existing" if use_existing else "new" + confirmed.append(f"{name} ({status})") + + if confirmed: + await reply.reply(f"\u2705 Identified: {', '.join(confirmed)}") + return + + except Exception as exc: + logger.warning("Face identification flow failed: %s", exc) + IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp"} @@ -92,6 +214,11 @@ async def _register_attachment(bot, config, message: discord.Message, attachment logger.info("[_register_attachment] Done") await reply.edit(view=view) + # Background face scan + optional identification prompt + asyncio.create_task( + _identify_faces_discord(bot, message, row_id, original_data) + ) + def setup_input_image_commands(bot, config=None): """Register input image commands and the on_message listener.""" diff --git a/face_db.py b/face_db.py new file mode 100644 index 0000000..f86c78c --- /dev/null +++ b/face_db.py @@ -0,0 +1,715 @@ 
+""" +face_db.py +========== + +SQLite persistence for face detection and identity data. + +Two tables: + persons : one row per named person + face_detections : one detected face per row, linked to source media +""" + +from __future__ import annotations + +import json +import sqlite3 +from datetime import datetime, timezone +from pathlib import Path + +import numpy as np + +_DB_PATH: Path = Path(__file__).parent / "faces.db" + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS persons ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + created_at TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS idx_persons_name ON persons(LOWER(name)); + +CREATE TABLE IF NOT EXISTS person_aliases ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + person_id INTEGER NOT NULL REFERENCES persons(id) ON DELETE CASCADE, + alias TEXT NOT NULL, + created_at TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS idx_aliases_alias ON person_aliases(LOWER(alias)); + +CREATE TABLE IF NOT EXISTS face_groups ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + label TEXT, + threshold REAL NOT NULL, + is_manual INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS face_detections ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + person_id INTEGER REFERENCES persons(id) ON DELETE SET NULL, + source_type TEXT NOT NULL, + source_id INTEGER NOT NULL, + frame_index INTEGER NOT NULL DEFAULT 0, + face_index INTEGER NOT NULL DEFAULT 0, + embedding BLOB, + bbox_json TEXT, + created_at TEXT NOT NULL, + identified_at TEXT +); +CREATE INDEX IF NOT EXISTS idx_fd_source ON face_detections(source_type, source_id); +CREATE INDEX IF NOT EXISTS idx_fd_person ON face_detections(person_id); +""" + + +def _apply_migrations(conn: sqlite3.Connection) -> None: + """Apply schema migrations that cannot be expressed in CREATE TABLE IF NOT EXISTS.""" + # Migration 0: Add group_id column + try: + conn.execute( + "ALTER TABLE face_detections ADD COLUMN group_id INTEGER REFERENCES face_groups(id) ON DELETE 
SET NULL" + ) + conn.commit() + except sqlite3.OperationalError: + pass # column already exists + conn.execute("CREATE INDEX IF NOT EXISTS idx_fd_group ON face_detections(group_id)") + conn.commit() + + # Migration A: Make embedding nullable (rebuild table if NOT NULL constraint exists) + ddl_row = conn.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name='face_detections'" + ).fetchone() + import re as _re + _ddl_normalized = " ".join((ddl_row["sql"] or "").split()) if ddl_row else "" + if "embedding BLOB NOT NULL" in _ddl_normalized: + conn.execute(""" + CREATE TABLE face_detections_new ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + person_id INTEGER REFERENCES persons(id) ON DELETE SET NULL, + source_type TEXT NOT NULL, + source_id INTEGER NOT NULL, + frame_index INTEGER NOT NULL DEFAULT 0, + face_index INTEGER NOT NULL DEFAULT 0, + embedding BLOB, + bbox_json TEXT, + created_at TEXT NOT NULL, + identified_at TEXT, + group_id INTEGER REFERENCES face_groups(id) ON DELETE SET NULL + ) + """) + conn.execute( + "INSERT INTO face_detections_new " + "SELECT id, person_id, source_type, source_id, frame_index, face_index, " + "embedding, bbox_json, created_at, identified_at, group_id FROM face_detections" + ) + conn.execute("DROP TABLE face_detections") + conn.execute("ALTER TABLE face_detections_new RENAME TO face_detections") + conn.execute("CREATE INDEX IF NOT EXISTS idx_fd_source ON face_detections(source_type, source_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_fd_person ON face_detections(person_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_fd_group ON face_detections(group_id)") + conn.commit() + + # Migration B: NULL out all existing output embeddings (optimize storage) + result = conn.execute( + "UPDATE face_detections SET embedding = NULL WHERE source_type = 'output' AND embedding IS NOT NULL" + ) + conn.commit() + if result.rowcount: + import logging as _logging + _logging.getLogger(__name__).info("Migration B: cleared %d output 
embedding(s)", result.rowcount) + + +def _connect(db_path: Path | None = None) -> sqlite3.Connection: + path = db_path if db_path is not None else _DB_PATH + conn = sqlite3.connect(str(path), check_same_thread=False) + conn.row_factory = sqlite3.Row + return conn + + +def init_db(db_path: Path = _DB_PATH) -> None: + """Create tables if they don't exist and apply schema migrations.""" + global _DB_PATH + _DB_PATH = db_path + with _connect(db_path) as conn: + conn.executescript(_SCHEMA) + conn.commit() + _apply_migrations(conn) + + +def get_or_create_person(name: str) -> tuple[int, bool]: + """Find or create a person by name (case-insensitive). Resolves aliases. Returns (person_id, created).""" + with _connect() as conn: + row = conn.execute( + "SELECT id FROM persons WHERE LOWER(name) = LOWER(?)", (name,) + ).fetchone() + if row: + return row["id"], False + row = conn.execute( + "SELECT person_id FROM person_aliases WHERE LOWER(alias) = LOWER(?)", (name,) + ).fetchone() + if row: + return row["person_id"], False + created_at = datetime.now(timezone.utc).isoformat() + cur = conn.execute( + "INSERT INTO persons (name, created_at) VALUES (?, ?)", + (name, created_at), + ) + conn.commit() + return cur.lastrowid, True # type: ignore[return-value] + + +def person_name_exists(name: str) -> bool: + """Return True if a person with that name or alias (case-insensitive) already exists.""" + with _connect() as conn: + row = conn.execute( + """SELECT 1 FROM persons WHERE LOWER(name) = LOWER(?) + UNION + SELECT 1 FROM person_aliases WHERE LOWER(alias) = LOWER(?) 
+ LIMIT 1""", + (name, name), + ).fetchone() + return row is not None + + +def get_person_by_name(name: str) -> dict | None: + """Return person row dict by name (case-insensitive), or None.""" + with _connect() as conn: + row = conn.execute( + "SELECT id, name, created_at FROM persons WHERE LOWER(name) = LOWER(?)", (name,) + ).fetchone() + return dict(row) if row else None + + +def list_persons() -> list[dict]: + """Return all persons sorted by name, each with their aliases list and face_count.""" + with _connect() as conn: + persons = conn.execute( + """SELECT p.id, p.name, p.created_at, + (SELECT COUNT(*) FROM face_detections WHERE person_id = p.id) AS face_count + FROM persons p ORDER BY p.name""" + ).fetchall() + result = [] + for p in persons: + d = dict(p) + aliases = conn.execute( + "SELECT id, alias FROM person_aliases WHERE person_id = ? ORDER BY alias", + (d["id"],), + ).fetchall() + d["aliases"] = [dict(a) for a in aliases] + result.append(d) + return result + + +def get_unidentified_input_detections(limit: int = 50, offset: int = 0) -> tuple[list[dict], int]: + """Return paginated ungrouped unidentified face detections from input images, plus total count.""" + with _connect() as conn: + total = conn.execute( + "SELECT COUNT(*) FROM face_detections WHERE person_id IS NULL AND group_id IS NULL AND source_type = 'input'" + ).fetchone()[0] + rows = conn.execute( + """SELECT id, source_id, face_index, bbox_json, created_at + FROM face_detections + WHERE person_id IS NULL AND group_id IS NULL AND source_type = 'input' + ORDER BY id + LIMIT ? OFFSET ?""", + (limit, offset), + ).fetchall() + return [dict(r) for r in rows], total + + +def add_alias(person_id: int, alias: str) -> tuple[int, bool]: + """Add an alias to a person. Returns (alias_id, created). 
Raises ValueError if alias taken.""" + if person_name_exists(alias): + raise ValueError(f"Name or alias '{alias}' is already taken") + created_at = datetime.now(timezone.utc).isoformat() + with _connect() as conn: + cur = conn.execute( + "INSERT INTO person_aliases (person_id, alias, created_at) VALUES (?, ?, ?)", + (person_id, alias, created_at), + ) + conn.commit() + return cur.lastrowid, True # type: ignore[return-value] + + +def remove_alias(alias_id: int) -> None: + """Delete an alias row by id.""" + with _connect() as conn: + conn.execute("DELETE FROM person_aliases WHERE id = ?", (alias_id,)) + conn.commit() + + +def count_detections_for_person(person_id: int) -> int: + """Return number of detections linked to this person.""" + with _connect() as conn: + row = conn.execute( + "SELECT COUNT(*) FROM face_detections WHERE person_id = ?", (person_id,) + ).fetchone() + return row[0] + + +def insert_detection( + source_type: str, + source_id: int, + embedding: "np.ndarray | None", + bbox: dict, + frame_index: int = 0, + face_index: int = 0, + person_id: int | None = None, +) -> int: + """Insert a face detection row. Returns the new row id.""" + embedding_bytes = embedding.astype(np.float32).tobytes() if embedding is not None else None + bbox_json = json.dumps(bbox) + created_at = datetime.now(timezone.utc).isoformat() + identified_at = created_at if person_id is not None else None + with _connect() as conn: + cur = conn.execute( + """ + INSERT INTO face_detections + (person_id, source_type, source_id, frame_index, face_index, + embedding, bbox_json, created_at, identified_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + (person_id, source_type, source_id, frame_index, face_index, + embedding_bytes, bbox_json, created_at, identified_at), + ) + conn.commit() + return cur.lastrowid # type: ignore[return-value] + + +def link_detection_to_person(detection_id: int, person_id: int) -> None: + """Associate a detection with a named person.""" + identified_at = datetime.now(timezone.utc).isoformat() + with _connect() as conn: + conn.execute( + "UPDATE face_detections SET person_id = ?, identified_at = ? WHERE id = ?", + (person_id, identified_at, detection_id), + ) + conn.commit() + prune_person_output_embeddings(person_id) + + +def prune_person_output_embeddings(person_id: int, max_keep: int = 5) -> int: + """Delete oldest output embeddings for a person if they exceed max_keep. + Input embeddings are never pruned. Returns number deleted.""" + with _connect() as conn: + rows = conn.execute( + """SELECT id FROM face_detections + WHERE person_id = ? AND source_type = 'output' AND embedding IS NOT NULL + ORDER BY id ASC""", + (person_id,), + ).fetchall() + excess = rows[:max(0, len(rows) - max_keep)] + if not excess: + return 0 + ids = [r["id"] for r in excess] + placeholders = ",".join("?" * len(ids)) + conn.execute( + f"UPDATE face_detections SET embedding = NULL WHERE id IN ({placeholders})", ids + ) + conn.commit() + return len(excess) + + +def get_detection(detection_id: int) -> dict | None: + """Return a single face_detection row by id, or None.""" + with _connect() as conn: + row = conn.execute( + "SELECT * FROM face_detections WHERE id = ?", (detection_id,) + ).fetchone() + return dict(row) if row else None + + +def get_detections_for_source(source_type: str, source_id: int) -> list[dict]: + """Return all face detections for a given source.""" + with _connect() as conn: + rows = conn.execute( + "SELECT * FROM face_detections WHERE source_type = ? 
AND source_id = ?", + (source_type, source_id), + ).fetchall() + return [dict(r) for r in rows] + + +def get_all_embeddings() -> list[dict]: + """Return all identified detections with their embeddings as np.ndarray.""" + with _connect() as conn: + rows = conn.execute( + "SELECT id, person_id, embedding FROM face_detections " + "WHERE person_id IS NOT NULL AND embedding IS NOT NULL" + ).fetchall() + result = [] + for row in rows: + emb = np.frombuffer(bytes(row["embedding"]), dtype=np.float32).copy() + result.append({"id": row["id"], "person_id": row["person_id"], "embedding": emb}) + return result + + +# --------------------------------------------------------------------------- +# Face group CRUD +# --------------------------------------------------------------------------- + + +def create_group(threshold: float, label: str | None = None, is_manual: bool = False) -> int: + """Insert a new face_groups row and return its id.""" + created_at = datetime.now(timezone.utc).isoformat() + with _connect() as conn: + cur = conn.execute( + "INSERT INTO face_groups (label, threshold, is_manual, created_at) VALUES (?, ?, ?, ?)", + (label, threshold, int(is_manual), created_at), + ) + conn.commit() + return cur.lastrowid # type: ignore[return-value] + + +def assign_detection_to_group(detection_id: int, group_id: int | None) -> None: + """Set or clear the group_id on a face_detections row.""" + with _connect() as conn: + conn.execute( + "UPDATE face_detections SET group_id = ? 
WHERE id = ?", + (group_id, detection_id), + ) + conn.commit() + + +def clear_all_groups() -> None: + """Remove all groups: NULL all group_id references then DELETE all face_groups rows.""" + with _connect() as conn: + conn.execute("UPDATE face_detections SET group_id = NULL WHERE group_id IS NOT NULL") + conn.execute("DELETE FROM face_groups") + conn.commit() + + +def get_groups_with_detections() -> list[dict]: + """Return groups that have ≥ 2 unidentified detections, with detection_ids and preview_ids.""" + with _connect() as conn: + groups = conn.execute( + "SELECT id, label, threshold, is_manual, created_at FROM face_groups ORDER BY id" + ).fetchall() + result = [] + for g in groups: + det_rows = conn.execute( + "SELECT id FROM face_detections WHERE group_id = ? AND person_id IS NULL ORDER BY id", + (g["id"],), + ).fetchall() + det_ids = [r["id"] for r in det_rows] + if len(det_ids) < 2: + continue + result.append({ + **dict(g), + "count": len(det_ids), + "detection_ids": det_ids, + "preview_ids": det_ids[:4], + }) + return result + + +def get_group_detections(group_id: int) -> list[dict]: + """Return unidentified detections belonging to a group.""" + with _connect() as conn: + rows = conn.execute( + """SELECT id, source_type, face_index, created_at + FROM face_detections + WHERE group_id = ? AND person_id IS NULL + ORDER BY id""", + (group_id,), + ).fetchall() + return [dict(r) for r in rows] + + +def merge_groups(keep_id: int, discard_id: int) -> None: + """Move all detections from discard_id to keep_id and delete the discard group.""" + with _connect() as conn: + conn.execute( + "UPDATE face_detections SET group_id = ? 
WHERE group_id = ?", + (keep_id, discard_id), + ) + conn.execute("DELETE FROM face_groups WHERE id = ?", (discard_id,)) + conn.commit() + + +def remove_detection_from_group(detection_id: int) -> None: + """Unassign a detection from its group (set group_id = NULL).""" + assign_detection_to_group(detection_id, None) + + +def delete_singleton_groups() -> None: + """Delete groups that have fewer than 2 unidentified members, NULL-ing their refs first.""" + with _connect() as conn: + rows = conn.execute( + """SELECT fg.id FROM face_groups fg + LEFT JOIN face_detections fd ON fd.group_id = fg.id AND fd.person_id IS NULL + GROUP BY fg.id + HAVING COUNT(fd.id) < 2""" + ).fetchall() + for row in rows: + gid = row["id"] + conn.execute("UPDATE face_detections SET group_id = NULL WHERE group_id = ?", (gid,)) + conn.execute("DELETE FROM face_groups WHERE id = ?", (gid,)) + conn.commit() + + +def delete_group(group_id: int) -> None: + """Delete a single face_groups row (does not NULL detections — caller must have already cleared them).""" + with _connect() as conn: + conn.execute("DELETE FROM face_groups WHERE id = ?", (group_id,)) + conn.commit() + + +def get_unidentified_embeddings() -> list[dict]: + """Return embeddings for all unidentified detections (both source types).""" + with _connect() as conn: + rows = conn.execute( + "SELECT id, embedding FROM face_detections WHERE person_id IS NULL AND embedding IS NOT NULL" + ).fetchall() + result = [] + for row in rows: + emb = np.frombuffer(bytes(row["embedding"]), dtype=np.float32).copy() + result.append({"id": row["id"], "embedding": emb}) + return result + + +def get_all_group_embeddings_with_threshold() -> list[dict]: + """Return per-group embedding lists along with the group threshold.""" + with _connect() as conn: + groups = conn.execute("SELECT id, threshold FROM face_groups").fetchall() + result = [] + for g in groups: + rows = conn.execute( + "SELECT embedding FROM face_detections " + "WHERE group_id = ? 
AND person_id IS NULL AND embedding IS NOT NULL", + (g["id"],), + ).fetchall() + embeddings = [ + np.frombuffer(bytes(r["embedding"]), dtype=np.float32).copy() + for r in rows + ] + if embeddings: + result.append({ + "group_id": g["id"], + "threshold": g["threshold"], + "embeddings": embeddings, + }) + return result + + +def get_person_embeddings(person_id: int) -> list[np.ndarray]: + """Return all embedding arrays for a given person (only rows with non-null embedding).""" + with _connect() as conn: + rows = conn.execute( + "SELECT embedding FROM face_detections WHERE person_id = ? AND embedding IS NOT NULL", + (person_id,), + ).fetchall() + return [np.frombuffer(bytes(r["embedding"]), dtype=np.float32).copy() for r in rows] + + +def get_ungrouped_unidentified_embeddings() -> list[dict]: + """Return embeddings for unidentified detections that have no group assigned.""" + with _connect() as conn: + rows = conn.execute( + "SELECT id, embedding FROM face_detections " + "WHERE person_id IS NULL AND group_id IS NULL AND embedding IS NOT NULL" + ).fetchall() + result = [] + for row in rows: + emb = np.frombuffer(bytes(row["embedding"]), dtype=np.float32).copy() + result.append({"id": row["id"], "embedding": emb}) + return result + + +# --------------------------------------------------------------------------- +# New person management functions +# --------------------------------------------------------------------------- + + +def get_person(person_id: int) -> dict | None: + """Return {id, name, created_at, aliases: [{id, alias}]} for a person, or None.""" + with _connect() as conn: + row = conn.execute( + "SELECT id, name, created_at FROM persons WHERE id = ?", (person_id,) + ).fetchone() + if row is None: + return None + d = dict(row) + aliases = conn.execute( + "SELECT id, alias FROM person_aliases WHERE person_id = ? 
ORDER BY alias", + (person_id,), + ).fetchall() + d["aliases"] = [dict(a) for a in aliases] + return d + + +def get_detections_for_person( + person_id: int, limit: int = 50, offset: int = 0 +) -> tuple[list[dict], int]: + """Return paginated detections for a person and total count (no embedding column).""" + with _connect() as conn: + total = conn.execute( + "SELECT COUNT(*) FROM face_detections WHERE person_id = ?", (person_id,) + ).fetchone()[0] + rows = conn.execute( + """SELECT id, source_type, source_id, face_index, frame_index, bbox_json, + created_at, identified_at + FROM face_detections + WHERE person_id = ? + ORDER BY id + LIMIT ? OFFSET ?""", + (person_id, limit, offset), + ).fetchall() + return [dict(r) for r in rows], total + + +def unidentify_detection(detection_id: int) -> None: + """Remove the person association from a detection.""" + with _connect() as conn: + conn.execute( + "UPDATE face_detections SET person_id = NULL, identified_at = NULL WHERE id = ?", + (detection_id,), + ) + conn.commit() + + +def rename_person(person_id: int, new_name: str) -> None: + """ + Rename a person. Allows case-only renames for the same person. + Raises ValueError if new_name is taken by a different person or any alias. + """ + with _connect() as conn: + # Check if name is taken by a different person + row = conn.execute( + "SELECT id FROM persons WHERE LOWER(name) = LOWER(?)", (new_name,) + ).fetchone() + if row and row["id"] != person_id: + raise ValueError(f"Name '{new_name}' is already taken by another person") + # Check if name is taken by any alias (aliases belong to any person) + alias_row = conn.execute( + "SELECT person_id FROM person_aliases WHERE LOWER(alias) = LOWER(?)", (new_name,) + ).fetchone() + if alias_row: + raise ValueError(f"Name '{new_name}' is already used as an alias") + conn.execute("UPDATE persons SET name = ? 
def delete_person(person_id: int) -> None:
    """
    Delete a person: NULL all their detections, delete aliases, delete person row.

    Cleanup is done with explicit statements rather than ON DELETE CASCADE
    because the SQLite foreign-key pragma is not enabled by default.
    """
    with _connect() as conn:
        conn.execute(
            "UPDATE face_detections SET person_id = NULL, identified_at = NULL WHERE person_id = ?",
            (person_id,),
        )
        conn.execute("DELETE FROM person_aliases WHERE person_id = ?", (person_id,))
        conn.execute("DELETE FROM persons WHERE id = ?", (person_id,))
        conn.commit()


def remove_person_from_source_ids(source_ids: list[int], person_id: int, source_type: str = "output") -> int:
    """Delete face_detections linking person_id to any of the given source_ids. Returns deleted count."""
    if not source_ids:
        return 0
    placeholders = ",".join("?" * len(source_ids))
    with _connect() as conn:
        cur = conn.execute(
            f"DELETE FROM face_detections WHERE source_type = ? AND person_id = ? AND source_id IN ({placeholders})",
            (source_type, person_id, *source_ids),
        )
        conn.commit()
        return cur.rowcount


def get_persons_for_source_ids(source_ids: list[int], source_type: str = "output") -> list[dict]:
    """Return distinct [{id, name}] for persons detected in the given source IDs."""
    if not source_ids:
        return []
    placeholders = ",".join("?" * len(source_ids))
    with _connect() as conn:
        rows = conn.execute(
            f"""SELECT DISTINCT p.id, p.name
                FROM face_detections fd JOIN persons p ON p.id = fd.person_id
                WHERE fd.source_type = ? AND fd.source_id IN ({placeholders}) AND fd.person_id IS NOT NULL
                ORDER BY p.name""",
            (source_type, *source_ids),
        ).fetchall()
        return [dict(r) for r in rows]


def get_source_ids_for_person_query(name_query: str, source_type: str) -> list[int]:
    """
    Return distinct source_id values where a person/alias matching the substring appears.

    Matching is case-insensitive over both the canonical name and all aliases.
    """
    if not name_query.strip():
        return []
    pattern = f"%{name_query}%"
    with _connect() as conn:
        rows = conn.execute(
            """SELECT DISTINCT fd.source_id
               FROM face_detections fd
               JOIN persons p ON p.id = fd.person_id
               LEFT JOIN person_aliases pa ON pa.person_id = p.id
               WHERE fd.source_type = ? AND fd.person_id IS NOT NULL
                 AND (LOWER(p.name) LIKE LOWER(?) OR LOWER(pa.alias) LIKE LOWER(?))""",
            (source_type, pattern, pattern),
        ).fetchall()
        return [r["source_id"] for r in rows]


def get_persons_for_source_id_map(source_ids: list[int], source_type: str) -> dict[int, list[str]]:
    """
    Return {source_id: [person_name, ...]} for the given source IDs (identified faces only).

    Deduplication is pushed into SQL via SELECT DISTINCT instead of the previous
    per-row membership test in Python; output is identical.
    """
    if not source_ids:
        return {}
    placeholders = ",".join("?" * len(source_ids))
    with _connect() as conn:
        rows = conn.execute(
            f"""SELECT DISTINCT fd.source_id, p.name
                FROM face_detections fd JOIN persons p ON p.id = fd.person_id
                WHERE fd.source_type = ? AND fd.source_id IN ({placeholders}) AND fd.person_id IS NOT NULL
                ORDER BY fd.source_id, p.name""",
            (source_type, *source_ids),
        ).fetchall()
    result: dict[int, list[str]] = {sid: [] for sid in source_ids}
    for row in rows:
        result[row["source_id"]].append(row["name"])
    return result


def update_detection_embedding(detection_id: int, embedding: "np.ndarray") -> None:
    """Update the embedding for an existing detection row (stored as raw float32 bytes)."""
    embedding_bytes = embedding.astype(np.float32).tobytes()
    with _connect() as conn:
        conn.execute(
            "UPDATE face_detections SET embedding = ? WHERE id = ?",
            (embedding_bytes, detection_id),
        )
        conn.commit()


def get_null_embedding_output_source_ids() -> list[int]:
    """Distinct source_ids of output detections with NULL embeddings and a stored bbox."""
    with _connect() as conn:
        rows = conn.execute(
            """SELECT DISTINCT source_id FROM face_detections
               WHERE source_type = 'output' AND embedding IS NULL
                 AND bbox_json IS NOT NULL AND bbox_json != '{}'"""
        ).fetchall()
        return [r["source_id"] for r in rows]


def merge_persons(survivor_id: int, other_id: int) -> None:
    """
    Absorb other_id into survivor_id: move all detections and aliases, then delete other_id.

    Raises ValueError if both ids are the same.

    NOTE(review): if person_aliases carries a UNIQUE(person_id, alias) constraint,
    moving an alias the survivor already owns would raise IntegrityError mid-merge —
    confirm against the schema and consider INSERT OR IGNORE + DELETE if so.
    """
    if survivor_id == other_id:
        raise ValueError("Cannot merge person into themselves")
    with _connect() as conn:
        conn.execute(
            "UPDATE face_detections SET person_id = ? WHERE person_id = ?",
            (survivor_id, other_id),
        )
        conn.execute(
            "UPDATE person_aliases SET person_id = ? WHERE person_id = ?",
            (survivor_id, other_id),
        )
        conn.execute("DELETE FROM persons WHERE id = ?", (other_id,))
        conn.commit()
"""
face_service.py
===============

FaceService: wrapper around insightface for face detection and recognition.

Runs CPU-bound work in a ThreadPoolExecutor(max_workers=1).
Falls back gracefully if insightface is not installed (available=False).
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional

import numpy as np

logger = logging.getLogger(__name__)

try:
    from insightface.app import FaceAnalysis as _FaceAnalysis
    _INSIGHTFACE_AVAILABLE = True
except ImportError:
    _FaceAnalysis = None  # type: ignore
    _INSIGHTFACE_AVAILABLE = False

# Cosine similarity at or above which two embeddings count as the same person.
_SIMILARITY_THRESHOLD = 0.4


@dataclass
class DetectedFace:
    # Result of a single face detection within one image/frame.
    face_index: int
    bbox: dict  # {x1, y1, x2, y2} in pixel coordinates
    embedding: np.ndarray
    crop_bytes: bytes  # JPEG bytes of the padded face crop


@dataclass
class ScanResult:
    # One stored detection, with the optional identification outcome.
    detection_id: int
    face_index: int
    bbox: dict
    matched_person_id: Optional[int]
    matched_person_name: Optional[str]


class FaceService:
    """Face detection/recognition facade; check `available` before use."""

    available: bool

    def __init__(self) -> None:
        self.available = _INSIGHTFACE_AVAILABLE
        # Single worker serializes CPU-bound insightface/OpenCV calls.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._app = None
        if self.available:
            try:
                self._app = _FaceAnalysis(providers=["CPUExecutionProvider"])
                self._app.prepare(ctx_id=0, det_size=(640, 640))
                logger.info("FaceService: insightface ready")
            except Exception as exc:
                # Model download/init can fail; degrade to unavailable rather than crash.
                logger.warning("FaceService: failed to init insightface: %s", exc)
                self.available = False

    # ------------------------------------------------------------------
    # Low-level detection
    # ------------------------------------------------------------------

    def _detect_sync(self, image_bytes: bytes) -> list[DetectedFace]:
        """CPU-bound: decode image bytes and detect faces with insightface."""
        import cv2
        arr = np.frombuffer(image_bytes, dtype=np.uint8)
        try:
            img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        except Exception as exc:
            logger.warning("FaceService: cv2.imdecode failed: %s", exc)
            return []
        if img is None:  # imdecode signals undecodable data via None, not an exception
            return []
        try:
            faces = self._app.get(img)
        except Exception as exc:
            logger.warning("FaceService: face detection failed: %s", exc)
            return []

        results = []
        for i, face in enumerate(faces):
            x1, y1, x2, y2 = (int(v) for v in face.bbox)
            bbox = {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
            emb = face.normed_embedding.astype(np.float32)

            # Crop with padding, clamped to image bounds.
            pad = 20
            h, w = img.shape[:2]
            cx1 = max(0, x1 - pad)
            cy1 = max(0, y1 - pad)
            cx2 = min(w, x2 + pad)
            cy2 = min(h, y2 + pad)
            crop = img[cy1:cy2, cx1:cx2]
            # imencode returns a success flag; previously it was silently ignored.
            ok, buf = cv2.imencode(".jpg", crop, [cv2.IMWRITE_JPEG_QUALITY, 85])
            crop_bytes = buf.tobytes() if ok else b""

            results.append(DetectedFace(
                face_index=i,
                bbox=bbox,
                embedding=emb,
                crop_bytes=crop_bytes,
            ))
        return results

    async def detect(self, image_bytes: bytes) -> list[DetectedFace]:
        """Async face detection wrapper; offloads to the single-worker executor."""
        # get_running_loop() is the correct call inside a coroutine
        # (get_event_loop() is deprecated here since Python 3.10).
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, self._detect_sync, image_bytes)

    # ------------------------------------------------------------------
    # Matching
    # ------------------------------------------------------------------

    def find_best_match(
        self,
        embedding: np.ndarray,
        known_list: list[dict],
        threshold: float = _SIMILARITY_THRESHOLD,
    ) -> tuple[Optional[int], float]:
        """
        Return (person_id, similarity) of the best cosine-similarity match, or (None, sim).

        `known_list` entries are {"person_id": int, "embedding": np.ndarray} with
        embeddings assumed normalized (dot product == cosine similarity).
        `threshold` generalizes the previously hard-coded module constant;
        the default preserves existing behavior.
        """
        if not known_list:
            return None, 0.0
        best_sim = 0.0
        best_id = None
        for entry in known_list:
            sim = float(np.dot(embedding, entry["embedding"]))
            if sim > best_sim:
                best_sim = sim
                best_id = entry["person_id"]
        if best_sim >= threshold:
            return best_id, best_sim
        return None, best_sim
+ Returns list of detection-id lists, one per cluster with ≥ 2 members. + """ + n = len(embeddings) + parent = list(range(n)) + + def find(x: int) -> int: + while parent[x] != x: + parent[x] = parent[parent[x]] + x = parent[x] + return x + + def union(x: int, y: int) -> None: + px, py = find(x), find(y) + if px != py: + parent[py] = px + + M = np.stack([e["embedding"] for e in embeddings]) + norms = np.linalg.norm(M, axis=1, keepdims=True) + M_norm = M / (norms + 1e-8) + sim_matrix = M_norm @ M_norm.T + + pairs = np.argwhere(sim_matrix >= threshold) + for i, j in pairs: + if i < j: + union(int(i), int(j)) + + groups: dict[int, list[int]] = {} + for idx, e in enumerate(embeddings): + root = find(idx) + groups.setdefault(root, []).append(e["id"]) + return [ids for ids in groups.values() if len(ids) >= 2] + + async def cluster_unidentified_faces(self, threshold: float = 0.45) -> list[list[int]]: + """ + Cluster all unidentified detections by embedding similarity and persist groups to face_db. + Clears existing groups before recomputing. + """ + if not self.available: + return [] + import face_db + embeddings = face_db.get_unidentified_embeddings() + if len(embeddings) < 2: + face_db.clear_all_groups() + return [] + loop = asyncio.get_event_loop() + groups = await loop.run_in_executor( + self._executor, self._cluster_sync, embeddings, threshold + ) + face_db.clear_all_groups() + for det_ids in groups: + gid = face_db.create_group(threshold) + for det_id in det_ids: + face_db.assign_detection_to_group(det_id, gid) + return groups + + def _assign_to_nearest_group_sync(self, embedding: np.ndarray) -> int | None: + """ + Compare embedding against existing group centroids and return the best matching group_id, + or None if no group exceeds its threshold. + Fast enough to call synchronously (< 50 groups × < 50 members). 
+ """ + import face_db + groups = face_db.get_all_group_embeddings_with_threshold() + if not groups: + return None + norm_emb = embedding / (np.linalg.norm(embedding) + 1e-8) + best_gid: int | None = None + best_sim = -1.0 + for g in groups: + M = np.stack(g["embeddings"]) + norms = np.linalg.norm(M, axis=1, keepdims=True) + M_norm = M / (norms + 1e-8) + mean_sim = float(np.mean(M_norm @ norm_emb)) + if mean_sim >= g["threshold"] and mean_sim > best_sim: + best_sim = mean_sim + best_gid = g["group_id"] + return best_gid + + # ------------------------------------------------------------------ + # High-level pipelines + # ------------------------------------------------------------------ + + async def scan_input_image(self, source_id: int, image_bytes: bytes) -> list[ScanResult]: + """Detect faces in an input image, auto-link if known, store to face_db.""" + if not self.available: + return [] + import face_db + faces = await self.detect(image_bytes) + if not faces: + return [] + known = face_db.get_all_embeddings() + persons_cache: dict[int, str] = {p["id"]: p["name"] for p in face_db.list_persons()} + results = [] + for face in faces: + person_id, _ = self.find_best_match(face.embedding, known) + person_name = persons_cache.get(person_id) if person_id is not None else None + det_id = face_db.insert_detection( + source_type="input", + source_id=source_id, + embedding=face.embedding, + bbox=face.bbox, + frame_index=0, + face_index=face.face_index, + person_id=person_id, + ) + if person_id is None: + gid = self._assign_to_nearest_group_sync(face.embedding) + if gid is not None: + face_db.assign_detection_to_group(det_id, gid) + results.append(ScanResult( + detection_id=det_id, + face_index=face.face_index, + bbox=face.bbox, + matched_person_id=person_id, + matched_person_name=person_name, + )) + return results + + async def scan_output_image(self, source_id: int, image_bytes: bytes) -> list[ScanResult]: + """Detect faces in a generated output image. 
async def scan_output_image(self, source_id: int, image_bytes: bytes) -> list[ScanResult]:
    """
    Detect faces in a generated output image. Silent background scan.

    Embeddings are deliberately NOT stored (saves space); rescan_output_embedding
    can fill them in later on demand from the stored bbox.
    """
    if not self.available:
        return []
    import face_db
    faces = await self.detect(image_bytes)
    if not faces:
        return []
    known = face_db.get_all_embeddings()
    persons_cache: dict[int, str] = {p["id"]: p["name"] for p in face_db.list_persons()}
    results = []
    for face in faces:
        person_id, _ = self.find_best_match(face.embedding, known)
        # BUGFIX: person_name must be computed per face. Previously it was only
        # assigned when person_id matched, so the first unmatched face raised
        # UnboundLocalError and later unmatched faces inherited a stale name.
        person_name = persons_cache.get(person_id) if person_id is not None else None
        det_id = face_db.insert_detection(
            source_type="output",
            source_id=source_id,
            embedding=None,  # discard; saves space; rescan fills on demand
            bbox=face.bbox,
            frame_index=0,
            face_index=face.face_index,
            person_id=person_id,
        )
        if person_id is None:
            gid = self._assign_to_nearest_group_sync(face.embedding)
            if gid is not None:
                face_db.assign_detection_to_group(det_id, gid)
        results.append(ScanResult(
            detection_id=det_id,
            face_index=face.face_index,
            bbox=face.bbox,
            matched_person_id=person_id,
            matched_person_name=person_name,
        ))
    logger.info(
        "Face scan [output image source_id=%d]: %d face(s) detected, %d matched",
        source_id, len(faces), sum(1 for r in results if r.matched_person_id is not None),
    )
    return results

def _extract_keyframes_sync(self, video_bytes: bytes, max_frames: int = 20) -> list:
    """Extract evenly-spaced keyframes from video bytes. Returns list of BGR numpy arrays."""
    import cv2
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        f.write(video_bytes)
        tmp_path = f.name
    try:
        cap = cv2.VideoCapture(tmp_path)
        try:
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total <= 0:
                return []
            n = min(max_frames, total)
            frames = []
            for idx in (int(i * total / n) for i in range(n)):
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, frame = cap.read()
                if ret:
                    frames.append(frame)
            return frames
        finally:
            # Previously release() was skipped if a read raised; release unconditionally.
            cap.release()
    finally:
        try:
            os.unlink(tmp_path)
        except Exception:
            pass

async def scan_video(self, source_id: int, video_bytes: bytes, max_frames: int = 20) -> list[ScanResult]:
    """Detect faces across video keyframes. Silent background scan."""
    if not self.available:
        return []
    import cv2
    import face_db
    loop = asyncio.get_running_loop()  # correct inside a coroutine
    frames = await loop.run_in_executor(
        self._executor, self._extract_keyframes_sync, video_bytes, max_frames
    )
    if not frames:
        return []
    known = face_db.get_all_embeddings()
    persons_cache: dict[int, str] = {p["id"]: p["name"] for p in face_db.list_persons()}
    results = []
    # NOTE: the previous seen_det_ids guard was dead code — insert_detection always
    # returns a fresh row id, so the membership test could never be False.
    for frame_idx, frame in enumerate(frames):
        _, buf = cv2.imencode(".jpg", frame)
        frame_bytes = buf.tobytes()
        faces = await self.detect(frame_bytes)
        for face in faces:
            person_id, _ = self.find_best_match(face.embedding, known)
            # BUGFIX: compute per face (see scan_output_image) — avoids
            # UnboundLocalError / stale names for unmatched faces.
            person_name = persons_cache.get(person_id) if person_id is not None else None
            det_id = face_db.insert_detection(
                source_type="output",
                source_id=source_id,
                embedding=None,  # discard; saves space; rescan fills on demand
                bbox=face.bbox,
                frame_index=frame_idx,
                face_index=face.face_index,
                person_id=person_id,
            )
            results.append(ScanResult(
                detection_id=det_id,
                face_index=face.face_index,
                bbox=face.bbox,
                matched_person_id=person_id,
                matched_person_name=person_name,
            ))
    logger.info(
        "Face scan [output video source_id=%d]: %d frame(s), %d result(s) matched",
        source_id, len(frames), len(results),
    )
    return results

async def rescan_output_embedding(self, source_id: int) -> int:
    """
    Re-detect faces in a stored output image and update NULL embeddings
    for existing detections by bbox proximity matching (<= 50 px center distance).

    Returns the count of detections updated. Videos are skipped (too expensive).
    """
    if not self.available:
        return 0
    import sqlite3
    import face_db
    import generation_db
    conn = sqlite3.connect(str(generation_db._DB_PATH), check_same_thread=False)
    conn.row_factory = sqlite3.Row
    row = conn.execute(
        "SELECT file_data, mime_type FROM generation_files WHERE id = ?", (source_id,)
    ).fetchone()
    conn.close()
    if row is None:
        return 0
    file_bytes = bytes(row["file_data"])
    mime = (row["mime_type"] or "").lower()
    if mime.startswith("video/"):
        return 0  # skip videos — too expensive for backfill

    faces = await self.detect(file_bytes)
    if not faces:
        return 0

    existing = [
        d for d in face_db.get_detections_for_source("output", source_id)
        if d.get("embedding") is None and d.get("bbox_json") not in (None, "{}")
    ]
    if not existing:
        return 0

    updated = 0
    known = None  # lazy; previously re-fetched inside the loop for every face
    for face in faces:
        fx = (face.bbox["x1"] + face.bbox["x2"]) / 2
        fy = (face.bbox["y1"] + face.bbox["y2"]) / 2
        best_det = None
        best_dist = float("inf")
        for det in existing:
            b = json.loads(det["bbox_json"])
            dx = fx - (b["x1"] + b["x2"]) / 2
            dy = fy - (b["y1"] + b["y2"]) / 2
            dist = (dx * dx + dy * dy) ** 0.5
            if dist < best_dist:
                best_dist = dist
                best_det = det
        if best_det is not None and best_dist <= 50:
            face_db.update_detection_embedding(best_det["id"], face.embedding)
            existing = [d for d in existing if d["id"] != best_det["id"]]
            updated += 1
            if best_det.get("person_id") is None:
                if known is None:
                    known = face_db.get_all_embeddings()
                matched_pid, _ = self.find_best_match(face.embedding, known)
                if matched_pid is not None:
                    face_db.link_detection_to_person(best_det["id"], matched_pid)

    return updated
# ------------------------------------------------------------------
# Utility
# ------------------------------------------------------------------

def _extract_frame_at_sync(
    self, video_bytes: bytes, frame_index: int, max_frames: int = 20,
    suffix: str = ".mp4",
) -> "np.ndarray | None":
    """
    Re-extract the specific video frame that was used during scan_video.

    frame_index is the enumeration index (0..n-1) used by scan_video, NOT the raw
    video frame number. We reconstruct the same sampling formula:
    actual_frame = int(frame_index * total / n) where n = min(max_frames, total)
    """
    import cv2
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
        f.write(video_bytes)
        tmp_path = f.name
    try:
        cap = cv2.VideoCapture(tmp_path)
        try:
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total <= 0:
                return None
            n = min(max_frames, total)
            if frame_index >= n:
                return None
            actual_idx = int(frame_index * total / n)
            cap.set(cv2.CAP_PROP_POS_FRAMES, actual_idx)
            ret, frame = cap.read()
            return frame if ret else None
        finally:
            # Previously the capture leaked when an exception was raised mid-read.
            cap.release()
    except Exception:
        return None
    finally:
        try:
            os.unlink(tmp_path)
        except Exception:
            pass

def get_face_crop(self, detection_id: int) -> bytes | None:
    """
    Re-derive the face crop from the stored source image or video frame.

    Returns JPEG bytes, or None if the detection, source, or frame is unavailable.
    """
    import cv2
    import face_db
    det = face_db.get_detection(detection_id)
    if det is None:
        return None
    source_type = det["source_type"]
    source_id = det["source_id"]
    bbox_raw = det["bbox_json"]
    if not bbox_raw:
        return None
    bbox = json.loads(bbox_raw)

    img = None
    if source_type == "input":
        from input_image_db import get_image_data
        image_bytes = get_image_data(source_id)
        if image_bytes is None:
            return None
        arr = np.frombuffer(image_bytes, dtype=np.uint8)
        try:
            img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        except Exception:
            return None
    elif source_type == "output":
        import sqlite3
        import generation_db
        conn = sqlite3.connect(str(generation_db._DB_PATH), check_same_thread=False)
        conn.row_factory = sqlite3.Row
        row = conn.execute(
            "SELECT file_data, mime_type FROM generation_files WHERE id = ?", (source_id,)
        ).fetchone()
        conn.close()
        if row is None:
            return None
        file_bytes = bytes(row["file_data"])
        mime = (row["mime_type"] or "").lower()
        if mime.startswith("video/"):
            frame_index = det.get("frame_index", 0) or 0
            # Pick a matching temp-file suffix so OpenCV selects the right codec
            _mime_to_ext = {"video/mp4": ".mp4", "video/webm": ".webm",
                            "video/avi": ".avi", "video/quicktime": ".mov"}
            vsuffix = _mime_to_ext.get(mime, ".mp4")
            img = self._extract_frame_at_sync(file_bytes, frame_index, suffix=vsuffix)
        else:
            arr = np.frombuffer(file_bytes, dtype=np.uint8)
            try:
                img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
            except Exception:
                return None

    if img is None:
        return None

    # Crop with the same padding/clamping used at detection time.
    x1, y1, x2, y2 = bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]
    pad = 20
    h, w = img.shape[:2]
    cx1 = max(0, x1 - pad)
    cy1 = max(0, y1 - pad)
    cx2 = min(w, x2 + pad)
    cy2 = min(h, y2 + pad)
    crop = img[cy1:cy2, cx1:cx2]
    # Check the encode success flag; previously a failed encode returned garbage bytes.
    ok, buf = cv2.imencode(".jpg", crop, [cv2.IMWRITE_JPEG_QUALITY, 85])
    return buf.tobytes() if ok else None


# Module-level singleton, created lazily on first use.
_face_service: "FaceService | None" = None


def get_face_service() -> "FaceService":
    """Return the process-wide FaceService, constructing it on first call."""
    global _face_service
    if _face_service is None:
        _face_service = FaceService()
    return _face_service
get_face_service() -> FaceService: + global _face_service + if _face_service is None: + _face_service = FaceService() + return _face_service diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 88302da..3f117e4 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -9,6 +9,8 @@ "version": "0.1.0", "dependencies": { "@tanstack/react-query": "^5.62.0", + "framer-motion": "^12.35.0", + "lucide-react": "^0.577.0", "react": "^18.3.1", "react-dom": "^18.3.1", "react-router-dom": "^6.27.0" @@ -1729,6 +1731,33 @@ "url": "https://github.com/sponsors/rawify" } }, + "node_modules/framer-motion": { + "version": "12.35.0", + "resolved": "https://artifactory.ubisoft.org/api/npm/npm/framer-motion/-/framer-motion-12.35.0.tgz", + "integrity": "sha512-w8hghCMQ4oq10j6aZh3U2yeEQv5K69O/seDI/41PK4HtgkLrcBovUNc0ayBC3UyyU7V1mrY2yLzvYdWJX9pGZQ==", + "license": "MIT", + "dependencies": { + "motion-dom": "^12.35.0", + "motion-utils": "^12.29.2", + "tslib": "^2.4.0" + }, + "peerDependencies": { + "@emotion/is-prop-valid": "*", + "react": "^18.0.0 || ^19.0.0", + "react-dom": "^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@emotion/is-prop-valid": { + "optional": true + }, + "react": { + "optional": true + }, + "react-dom": { + "optional": true + } + } + }, "node_modules/fsevents": { "version": "2.3.3", "resolved": "https://artifactory.ubisoft.org/api/npm/npm/fsevents/-/fsevents-2.3.3.tgz", @@ -1936,6 +1965,15 @@ "yallist": "^3.0.2" } }, + "node_modules/lucide-react": { + "version": "0.577.0", + "resolved": "https://artifactory.ubisoft.org/api/npm/npm/lucide-react/-/lucide-react-0.577.0.tgz", + "integrity": "sha512-4LjoFv2eEPwYDPg/CUdBJQSDfPyzXCRrVW1X7jrx/trgxnxkHFjnVZINbzvzxjN70dxychOfg+FTYwBiS3pQ5A==", + "license": "ISC", + "peerDependencies": { + "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://artifactory.ubisoft.org/api/npm/npm/merge2/-/merge2-1.4.1.tgz", @@ 
-1960,6 +1998,21 @@ "node": ">=8.6" } }, + "node_modules/motion-dom": { + "version": "12.35.0", + "resolved": "https://artifactory.ubisoft.org/api/npm/npm/motion-dom/-/motion-dom-12.35.0.tgz", + "integrity": "sha512-FFMLEnIejK/zDABn+vqGVAUN4T0+3fw+cVAY8MMT65yR+j5uMuvWdd4npACWhh94OVWQs79CrBBuwOwGRZAQiA==", + "license": "MIT", + "dependencies": { + "motion-utils": "^12.29.2" + } + }, + "node_modules/motion-utils": { + "version": "12.29.2", + "resolved": "https://artifactory.ubisoft.org/api/npm/npm/motion-utils/-/motion-utils-12.29.2.tgz", + "integrity": "sha512-G3kc34H2cX2gI63RqU+cZq+zWRRPSsNIOjpdl9TN4AQwC4sgwYPl/Q/Obf/d53nOm569T0fYK+tcoSV50BWx8A==", + "license": "MIT" + }, "node_modules/ms": { "version": "2.1.3", "resolved": "https://artifactory.ubisoft.org/api/npm/npm/ms/-/ms-2.1.3.tgz", @@ -2651,6 +2704,12 @@ "dev": true, "license": "Apache-2.0" }, + "node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://artifactory.ubisoft.org/api/npm/npm/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, "node_modules/typescript": { "version": "5.9.3", "resolved": "https://artifactory.ubisoft.org/api/npm/npm/typescript/-/typescript-5.9.3.tgz", diff --git a/frontend/package.json b/frontend/package.json index 4cfa59f..cb04f5c 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -9,10 +9,12 @@ "preview": "vite preview" }, "dependencies": { + "@tanstack/react-query": "^5.62.0", + "framer-motion": "^12.35.0", + "lucide-react": "^0.577.0", "react": "^18.3.1", "react-dom": "^18.3.1", - "react-router-dom": "^6.27.0", - "@tanstack/react-query": "^5.62.0" + "react-router-dom": "^6.27.0" }, "devDependencies": { "@types/react": "^18.3.12", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index fd7cc7d..46ef102 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -12,6 +12,7 @@ import StatusPage from './pages/StatusPage' import 
ServerPage from './pages/ServerPage' import HistoryPage from './pages/HistoryPage' import AdminPage from './pages/AdminPage' +import FacesPage from './pages/FacesPage' import SharePage from './pages/SharePage' function RequireAuth({ children }: { children: React.ReactNode }) { @@ -59,6 +60,14 @@ export default function App() { } /> + + + + } + /> } /> } /> diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 21c3b22..4e287b6 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -93,8 +93,14 @@ export interface InputImage { filename: string is_active: number active_slot_key: string | null + detected_persons?: string[] +} +export const listInputs = (persons?: string[]) => { + const params = new URLSearchParams() + persons?.forEach(p => params.append('persons', p)) + const qs = params.toString() + return _fetch(qs ? `/api/inputs?${qs}` : '/api/inputs') } -export const listInputs = () => _fetch('/api/inputs') export const uploadInput = (file: File, slotKey = 'input_image') => { const form = new FormData() @@ -142,11 +148,44 @@ export const tailLogs = (lines = 100) => _fetch<{ lines: string[] }>(`/api/logs/tail?lines=${lines}`) // History -export const getHistory = (q?: string) => - _fetch<{ history: Array> }>(q ? `/api/history?q=${encodeURIComponent(q)}` : '/api/history') +export const getHistory = (q?: string, persons?: string[]) => { + const params = new URLSearchParams() + if (q) params.set('q', q) + persons?.forEach(p => params.append('persons', p)) + const qs = params.toString() + return _fetch<{ history: Array> }>(qs ? 
`/api/history?${qs}` : '/api/history') +} -export const createHistoryShare = (promptId: string) => - _fetch<{ share_token: string }>(`/api/history/${promptId}/share`, { method: 'POST' }) +export const getGenerationPersons = (promptId: string) => + _fetch<{ persons: Array<{ id: number; name: string }> }>(`/api/history/${promptId}/persons`) + +export const addGenerationPerson = (promptId: string, name: string) => + _fetch<{ person_id: number; name: string }>(`/api/history/${promptId}/persons`, { + method: 'POST', + body: JSON.stringify({ name }), + }) + +export const removeGenerationPerson = (promptId: string, personId: number) => + _fetch<{ ok: boolean }>(`/api/history/${promptId}/persons/${personId}`, { method: 'DELETE' }) + +export interface ShareOptions { + is_public?: boolean + expires_in_hours?: number + max_views?: number +} + +export interface ShareResult { + share_token: string + is_public: boolean + expires_at: string | null + max_views: number | null +} + +export const createHistoryShare = (promptId: string, options?: ShareOptions) => + _fetch(`/api/history/${promptId}/share`, { + method: 'POST', + body: JSON.stringify(options ?? 
{}), + }) export const revokeHistoryShare = (promptId: string) => _fetch<{ ok: boolean }>(`/api/history/${promptId}/share`, { method: 'DELETE' }) @@ -196,3 +235,165 @@ export const loadWorkflow = (filename: string) => { export const getModels = (type: 'checkpoints' | 'loras') => _fetch<{ type: string; models: string[] }>(`/api/workflow/models?type=${type}`) +// Faces +export interface PendingFace { + detection_id: number + face_index: number + bbox: Record +} + +export interface Alias { + id: number + alias: string +} + +export interface Person { + id: number + name: string + created_at: string + aliases: Alias[] + face_count: number +} + +export interface UnidentifiedDetection { + id: number + source_id: number + face_index: number + bbox_json: string + created_at: string +} + +export const listPersons = () => + _fetch<{ persons: Person[] }>('/api/faces/persons') + +export const checkPersonName = (name: string) => + _fetch<{ exists: boolean }>(`/api/faces/persons/check?name=${encodeURIComponent(name)}`) + +export const identifyFaces = ( + identifications: Array<{ detection_id: number; name: string; use_existing: boolean }> +) => + _fetch<{ identifications: Array<{ detection_id: number; person_id: number; person_name: string; is_new: boolean }>; auto_linked_count: number }>( + '/api/faces/identify', + { method: 'POST', body: JSON.stringify({ identifications }) } + ) + +export const faceCropUrl = (detectionId: number) => `/api/faces/crop/${detectionId}` + +export const getUnidentifiedDetections = (limit = 50, offset = 0) => + _fetch<{ detections: UnidentifiedDetection[]; total: number }>( + `/api/faces/detections/unidentified?limit=${limit}&offset=${offset}` + ) + +export const addPersonAlias = (personId: number, alias: string) => + _fetch<{ id: number; alias: string }>(`/api/faces/persons/${personId}/aliases`, { + method: 'POST', + body: JSON.stringify({ alias }), + }) + +export const removePersonAlias = (personId: number, aliasId: number) => + _fetch<{ ok: boolean 
}>(`/api/faces/persons/${personId}/aliases/${aliasId}`, { + method: 'DELETE', + }) + +export interface FaceGroup { + id: number + label: string | null + threshold: number + is_manual: boolean + created_at: string + count: number + detection_ids: number[] + preview_ids: number[] +} + +export interface GroupDetection { + id: number + source_type: 'input' | 'output' + face_index: number + created_at: string +} + +export const listFaceGroups = () => + _fetch<{ groups: FaceGroup[]; total: number }>('/api/faces/groups') + +export const getFaceGroupDetections = (groupId: number) => + _fetch<{ detections: GroupDetection[] }>(`/api/faces/groups/${groupId}/detections`) + +export const computeFaceGroups = (threshold: number) => + _fetch<{ groups_created: number; total_detections_clustered: number; threshold: number }>( + '/api/faces/groups/compute', + { method: 'POST', body: JSON.stringify({ threshold }) } + ) + +export const mergeFaceGroups = (keepId: number, discardId: number) => + _fetch<{ ok: boolean; surviving_group_id: number }>('/api/faces/groups/merge', { + method: 'POST', + body: JSON.stringify({ keep_group_id: keepId, discard_group_id: discardId }), + }) + +export const identifyFaceGroup = (groupId: number, name: string, useExisting = false) => + _fetch<{ + person_id: number; person_name: string; is_new: boolean + identified_count: number; auto_linked_count: number + }>(`/api/faces/groups/${groupId}/identify`, { + method: 'POST', + body: JSON.stringify({ name, use_existing: useExisting }), + }) + +export const removeDetectionFromGroup = (groupId: number, detectionId: number) => + _fetch<{ ok: boolean }>( + `/api/faces/groups/${groupId}/detections/${detectionId}`, + { method: 'DELETE' } + ) + +export const rescanOutputEmbeddings = () => + _fetch<{ processed: number; updated: number }>('/api/faces/rescan/outputs', { method: 'POST' }) + +export interface PersonDetection { + id: number + source_type: 'input' | 'output' + source_id: number + face_index: number + 
frame_index: number + bbox_json: string | null + created_at: string + identified_at: string | null +} + +export const getPerson = (personId: number) => + _fetch(`/api/faces/persons/${personId}`) + +export const getPersonDetections = (personId: number, limit = 50, offset = 0) => + _fetch<{ detections: PersonDetection[]; total: number }>( + `/api/faces/persons/${personId}/detections?limit=${limit}&offset=${offset}` + ) + +export const searchPersons = (q: string, limit = 10) => + _fetch<{ persons: Person[] }>( + `/api/faces/persons/search?q=${encodeURIComponent(q)}&limit=${limit}` + ) + +export const reassignDetection = ( + detectionId: number, + personName: string | null, + useExisting = false, +) => + _fetch<{ detection_id: number; person_id: number | null; is_new?: boolean; unidentified?: boolean }>( + `/api/faces/detections/${detectionId}/reassign`, + { method: 'POST', body: JSON.stringify({ person_name: personName, use_existing: useExisting }) } + ) + +export const renamePerson = (personId: number, name: string) => + _fetch<{ ok: boolean; person_id: number; name: string }>(`/api/faces/persons/${personId}`, { + method: 'PATCH', + body: JSON.stringify({ name }), + }) + +export const deletePerson = (personId: number) => + _fetch<{ ok: boolean }>(`/api/faces/persons/${personId}`, { method: 'DELETE' }) + +export const mergePersons = (survivorId: number, otherId: number) => + _fetch<{ ok: boolean; survivor_id: number; absorbed_id: number }>( + `/api/faces/persons/${survivorId}/merge`, + { method: 'POST', body: JSON.stringify({ other_person_id: otherId }) } + ) diff --git a/frontend/src/components/DynamicWorkflowForm.tsx b/frontend/src/components/DynamicWorkflowForm.tsx index 767b498..1ae8980 100644 --- a/frontend/src/components/DynamicWorkflowForm.tsx +++ b/frontend/src/components/DynamicWorkflowForm.tsx @@ -14,6 +14,7 @@ import { getInputMid, } from '../api/client' import LazyImage from './LazyImage' +import { X } from 'lucide-react' interface Props { /** Called when 
the Generate button is clicked with the current overrides */ @@ -47,20 +48,18 @@ export default function DynamicWorkflowForm({ onGenerate, lastSeed, generating, }) const { data: inputImages } = useQuery({ queryKey: ['inputs'], - queryFn: listInputs, + queryFn: () => listInputs(), }) const [localValues, setLocalValues] = useState>({}) const [randomSeeds, setRandomSeeds] = useState>({}) - const [imagePicker, setImagePicker] = useState(null) // key of slot being picked + const [imagePicker, setImagePicker] = useState(null) const [count, setCount] = useState(1) - // Sync local values from state when stateData arrives useEffect(() => { if (stateData) setLocalValues(stateData as Record) }, [stateData]) - // Update seed field when WS reports completed seed useEffect(() => { if (lastSeed != null) { setLocalValues(v => ({ ...v, seed: lastSeed })) @@ -114,7 +113,7 @@ export default function DynamicWorkflowForm({ onGenerate, lastSeed, generating, return (