"""
face_service.py
===============
FaceService: wrapper around insightface for face detection and recognition.
Runs CPU-bound work in a ThreadPoolExecutor(max_workers=1).
Falls back gracefully if insightface is not installed (available=False).
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional
import numpy as np
logger = logging.getLogger(__name__)
try:
from insightface.app import FaceAnalysis as _FaceAnalysis
_INSIGHTFACE_AVAILABLE = True
except ImportError:
_FaceAnalysis = None # type: ignore
_INSIGHTFACE_AVAILABLE = False
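# Minimum cosine similarity for auto-linking a detection to a known person.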
_SIMILARITY_THRESHOLD = 0.4
@dataclass
class DetectedFace:
face_index: int
bbox: dict # {x1, y1, x2, y2}
embedding: np.ndarray
crop_bytes: bytes # JPEG bytes of the face crop
@dataclass
class ScanResult:
detection_id: int
face_index: int
bbox: dict
matched_person_id: Optional[int]
matched_person_name: Optional[str]
class FaceService:
available: bool
def __init__(self) -> None:
self.available = _INSIGHTFACE_AVAILABLE
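        # A single worker serializes all CPU-bound inference across concurrent requests.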
self._executor = ThreadPoolExecutor(max_workers=1)
self._app = None
if self.available:
try:
self._app = _FaceAnalysis(providers=["CPUExecutionProvider"])
self._app.prepare(ctx_id=0, det_size=(640, 640))
logger.info("FaceService: insightface ready")
except Exception as exc:
logger.warning("FaceService: failed to init insightface: %s", exc)
self.available = False
# ------------------------------------------------------------------
# Low-level detection
# ------------------------------------------------------------------
def _detect_sync(self, image_bytes: bytes) -> list[DetectedFace]:
"""CPU-bound: detect faces in image bytes."""
import cv2
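        # Decode raw bytes to a BGR image; imdecode returns None for undecodable input.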
arr = np.frombuffer(image_bytes, dtype=np.uint8)
try:
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
except Exception as exc:
logger.warning("FaceService: cv2.imdecode failed: %s", exc)
return []
if img is None:
return []
try:
faces = self._app.get(img)
except Exception as exc:
logger.warning("FaceService: face detection failed: %s", exc)
return []
results = []
for i, face in enumerate(faces):
x1, y1, x2, y2 = (int(v) for v in face.bbox)
bbox = {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
emb = face.normed_embedding.astype(np.float32)
# Crop with padding
pad = 20
h, w = img.shape[:2]
cx1 = max(0, x1 - pad)
cy1 = max(0, y1 - pad)
cx2 = min(w, x2 + pad)
cy2 = min(h, y2 + pad)
crop = img[cy1:cy2, cx1:cx2]
_, buf = cv2.imencode(".jpg", crop, [cv2.IMWRITE_JPEG_QUALITY, 85])
crop_bytes = buf.tobytes()
results.append(DetectedFace(
face_index=i,
bbox=bbox,
embedding=emb,
crop_bytes=crop_bytes,
))
return results
async def detect(self, image_bytes: bytes) -> list[DetectedFace]:
"""Async face detection wrapper."""
        loop = asyncio.get_running_loop()
return await loop.run_in_executor(self._executor, self._detect_sync, image_bytes)
# ------------------------------------------------------------------
# Matching
# ------------------------------------------------------------------
def find_best_match(
self,
embedding: np.ndarray,
known_list: list[dict],
) -> tuple[Optional[int], float]:
"""Return (person_id, similarity) of the best cosine-similarity match, or (None, 0.0)."""
if not known_list:
return None, 0.0
best_sim = 0.0
best_id = None
for entry in known_list:
sim = float(np.dot(embedding, entry["embedding"]))
if sim > best_sim:
best_sim = sim
best_id = entry["person_id"]
if best_sim >= _SIMILARITY_THRESHOLD:
return best_id, best_sim
return None, best_sim
# ------------------------------------------------------------------
# Clustering
# ------------------------------------------------------------------
def _cluster_sync(self, embeddings: list[dict], threshold: float) -> list[list[int]]:
"""
Union-find clustering of face embeddings by cosine similarity.
O(n²) memory — suitable for up to ~10k faces (1000 faces ≈ 4 MB float32).
Returns list of detection-id lists, one per cluster with ≥ 2 members.
"""
n = len(embeddings)
parent = list(range(n))
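        # Union-find with path halving: find() flattens the tree as it walks it.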
def find(x: int) -> int:
while parent[x] != x:
parent[x] = parent[parent[x]]
x = parent[x]
return x
def union(x: int, y: int) -> None:
px, py = find(x), find(y)
if px != py:
parent[py] = px
M = np.stack([e["embedding"] for e in embeddings])
norms = np.linalg.norm(M, axis=1, keepdims=True)
M_norm = M / (norms + 1e-8)
sim_matrix = M_norm @ M_norm.T
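        # argwhere emits every qualifying (i, j) including the diagonal and both
        # orders; the i < j filter below unions each pair exactly once.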
pairs = np.argwhere(sim_matrix >= threshold)
for i, j in pairs:
if i < j:
union(int(i), int(j))
groups: dict[int, list[int]] = {}
for idx, e in enumerate(embeddings):
root = find(idx)
groups.setdefault(root, []).append(e["id"])
return [ids for ids in groups.values() if len(ids) >= 2]
async def cluster_unidentified_faces(self, threshold: float = 0.45) -> list[list[int]]:
"""
Cluster all unidentified detections by embedding similarity and persist groups to face_db.
Clears existing groups before recomputing.
"""
if not self.available:
return []
import face_db
embeddings = face_db.get_unidentified_embeddings()
if len(embeddings) < 2:
face_db.clear_all_groups()
return []
        loop = asyncio.get_running_loop()
groups = await loop.run_in_executor(
self._executor, self._cluster_sync, embeddings, threshold
)
face_db.clear_all_groups()
for det_ids in groups:
gid = face_db.create_group(threshold)
for det_id in det_ids:
face_db.assign_detection_to_group(det_id, gid)
return groups
def _assign_to_nearest_group_sync(self, embedding: np.ndarray) -> int | None:
"""
Compare embedding against existing group centroids and return the best matching group_id,
or None if no group exceeds its threshold.
Fast enough to call synchronously (< 50 groups × < 50 members).
"""
import face_db
groups = face_db.get_all_group_embeddings_with_threshold()
if not groups:
return None
norm_emb = embedding / (np.linalg.norm(embedding) + 1e-8)
best_gid: int | None = None
best_sim = -1.0
for g in groups:
M = np.stack(g["embeddings"])
norms = np.linalg.norm(M, axis=1, keepdims=True)
M_norm = M / (norms + 1e-8)
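            # Mean cosine similarity over all members, i.e. the dot product with
            # the (unnormalized) centroid of the group's unit embeddings.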
mean_sim = float(np.mean(M_norm @ norm_emb))
if mean_sim >= g["threshold"] and mean_sim > best_sim:
best_sim = mean_sim
best_gid = g["group_id"]
return best_gid
# ------------------------------------------------------------------
# High-level pipelines
# ------------------------------------------------------------------
async def scan_input_image(self, source_id: int, image_bytes: bytes) -> list[ScanResult]:
"""Detect faces in an input image, auto-link if known, store to face_db."""
if not self.available:
return []
import face_db
faces = await self.detect(image_bytes)
if not faces:
return []
known = face_db.get_all_embeddings()
persons_cache: dict[int, str] = {p["id"]: p["name"] for p in face_db.list_persons()}
results = []
for face in faces:
person_id, _ = self.find_best_match(face.embedding, known)
person_name = persons_cache.get(person_id) if person_id is not None else None
det_id = face_db.insert_detection(
source_type="input",
source_id=source_id,
embedding=face.embedding,
bbox=face.bbox,
frame_index=0,
face_index=face.face_index,
person_id=person_id,
)
if person_id is None:
gid = self._assign_to_nearest_group_sync(face.embedding)
if gid is not None:
face_db.assign_detection_to_group(det_id, gid)
results.append(ScanResult(
detection_id=det_id,
face_index=face.face_index,
bbox=face.bbox,
matched_person_id=person_id,
matched_person_name=person_name,
))
return results
async def scan_output_image(self, source_id: int, image_bytes: bytes) -> list[ScanResult]:
"""Detect faces in a generated output image. Silent background scan."""
if not self.available:
return []
import face_db
faces = await self.detect(image_bytes)
if not faces:
return []
known = face_db.get_all_embeddings()
persons_cache: dict[int, str] = {p["id"]: p["name"] for p in face_db.list_persons()}
results = []
for face in faces:
person_id, _ = self.find_best_match(face.embedding, known)
det_id = face_db.insert_detection(
source_type="output",
source_id=source_id,
embedding=None, # discard; saves space; rescan fills on demand
bbox=face.bbox,
frame_index=0,
face_index=face.face_index,
person_id=person_id,
)
if person_id is None:
gid = self._assign_to_nearest_group_sync(face.embedding)
if gid is not None:
face_db.assign_detection_to_group(det_id, gid)
if person_id is not None:
person_name = persons_cache.get(person_id)
results.append(ScanResult(
detection_id=det_id,
face_index=face.face_index,
bbox=face.bbox,
matched_person_id=person_id,
matched_person_name=person_name,
))
logger.info(
"Face scan [output image source_id=%d]: %d face(s) detected, %d matched",
source_id, len(faces), sum(1 for r in results if r.matched_person_id is not None),
)
return results
def _extract_keyframes_sync(self, video_bytes: bytes, max_frames: int = 20) -> list:
"""Extract evenly-spaced keyframes from video bytes. Returns list of BGR numpy arrays."""
import cv2
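        # cv2.VideoCapture needs a real file path, so spill the bytes to a temp file.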
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
f.write(video_bytes)
tmp_path = f.name
try:
cap = cv2.VideoCapture(tmp_path)
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if total <= 0:
cap.release()
return []
n = min(max_frames, total)
indices = [int(i * total / n) for i in range(n)]
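            # Evenly spaced sampling: e.g. total=200, n=20 -> indices 0, 10, ..., 190.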
frames = []
for idx in indices:
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
ret, frame = cap.read()
if ret:
frames.append(frame)
cap.release()
return frames
finally:
try:
os.unlink(tmp_path)
except Exception:
pass
async def scan_video(self, source_id: int, video_bytes: bytes, max_frames: int = 20) -> list[ScanResult]:
"""Detect faces across video keyframes. Silent background scan."""
if not self.available:
return []
import cv2
import face_db
        loop = asyncio.get_running_loop()
frames = await loop.run_in_executor(
self._executor, self._extract_keyframes_sync, video_bytes, max_frames
)
if not frames:
return []
known = face_db.get_all_embeddings()
persons_cache: dict[int, str] = {p["id"]: p["name"] for p in face_db.list_persons()}
results = []
seen_det_ids: set[int] = set()
for frame_idx, frame in enumerate(frames):
_, buf = cv2.imencode(".jpg", frame)
frame_bytes = buf.tobytes()
faces = await self.detect(frame_bytes)
for face in faces:
person_id, _ = self.find_best_match(face.embedding, known)
det_id = face_db.insert_detection(
source_type="output",
source_id=source_id,
embedding=None, # discard; saves space; rescan fills on demand
bbox=face.bbox,
frame_index=frame_idx,
face_index=face.face_index,
person_id=person_id,
)
if det_id not in seen_det_ids:
seen_det_ids.add(det_id)
if person_id is not None:
person_name = persons_cache.get(person_id)
results.append(ScanResult(
detection_id=det_id,
face_index=face.face_index,
bbox=face.bbox,
matched_person_id=person_id,
matched_person_name=person_name,
))
logger.info(
"Face scan [output video source_id=%d]: %d frame(s), %d result(s) matched",
source_id, len(frames), len(results),
)
return results
async def rescan_output_embedding(self, source_id: int) -> int:
"""
Re-detect faces in a stored output image and update NULL embeddings
for existing detections by bbox proximity matching.
Returns count of detections updated.
"""
if not self.available:
return 0
import sqlite3
import face_db
import generation_db
conn = sqlite3.connect(str(generation_db._DB_PATH), check_same_thread=False)
conn.row_factory = sqlite3.Row
row = conn.execute(
"SELECT file_data, mime_type FROM generation_files WHERE id = ?", (source_id,)
).fetchone()
conn.close()
if row is None:
return 0
file_bytes = bytes(row["file_data"])
mime = (row["mime_type"] or "").lower()
if mime.startswith("video/"):
return 0 # skip videos — too expensive for backfill
faces = await self.detect(file_bytes)
if not faces:
return 0
existing = [
d for d in face_db.get_detections_for_source("output", source_id)
if d.get("embedding") is None and d.get("bbox_json") not in (None, "{}")
]
if not existing:
return 0
updated = 0
for face in faces:
fx = (face.bbox["x1"] + face.bbox["x2"]) / 2
fy = (face.bbox["y1"] + face.bbox["y2"]) / 2
best_det = None
best_dist = float("inf")
for det in existing:
b = json.loads(det["bbox_json"])
dx = fx - (b["x1"] + b["x2"]) / 2
dy = fy - (b["y1"] + b["y2"]) / 2
dist = (dx * dx + dy * dy) ** 0.5
if dist < best_dist:
best_dist = dist
best_det = det
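            # Accept the nearest candidate only if the bbox centers are within 50 px.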
if best_det is not None and best_dist <= 50:
face_db.update_detection_embedding(best_det["id"], face.embedding)
existing = [d for d in existing if d["id"] != best_det["id"]]
updated += 1
if best_det.get("person_id") is None:
known = face_db.get_all_embeddings()
matched_pid, _ = self.find_best_match(face.embedding, known)
if matched_pid is not None:
face_db.link_detection_to_person(best_det["id"], matched_pid)
return updated
# ------------------------------------------------------------------
# Utility
# ------------------------------------------------------------------
def _extract_frame_at_sync(
self, video_bytes: bytes, frame_index: int, max_frames: int = 20,
suffix: str = ".mp4",
) -> "np.ndarray | None":
"""
Re-extract the specific video frame that was used during scan_video.
frame_index is the enumeration index (0…n-1) used by scan_video, NOT the raw
video frame number. We reconstruct the same sampling formula:
actual_frame = int(frame_index * total / n) where n = min(max_frames, total)
"""
import cv2
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
f.write(video_bytes)
tmp_path = f.name
try:
cap = cv2.VideoCapture(tmp_path)
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if total <= 0:
cap.release()
return None
n = min(max_frames, total)
if frame_index >= n:
cap.release()
return None
actual_idx = int(frame_index * total / n)
cap.set(cv2.CAP_PROP_POS_FRAMES, actual_idx)
ret, frame = cap.read()
cap.release()
return frame if ret else None
except Exception:
return None
finally:
try:
os.unlink(tmp_path)
except Exception:
pass
def get_face_crop(self, detection_id: int) -> bytes | None:
"""Re-derive the face crop from the stored source image or video frame. Returns JPEG bytes or None."""
import cv2
import face_db
det = face_db.get_detection(detection_id)
if det is None:
return None
source_type = det["source_type"]
source_id = det["source_id"]
bbox_raw = det["bbox_json"]
if not bbox_raw:
return None
bbox = json.loads(bbox_raw)
img = None
if source_type == "input":
from input_image_db import get_image_data
image_bytes = get_image_data(source_id)
if image_bytes is None:
return None
arr = np.frombuffer(image_bytes, dtype=np.uint8)
try:
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
except Exception:
return None
elif source_type == "output":
import sqlite3
import generation_db
conn = sqlite3.connect(str(generation_db._DB_PATH), check_same_thread=False)
conn.row_factory = sqlite3.Row
row = conn.execute(
"SELECT file_data, mime_type FROM generation_files WHERE id = ?", (source_id,)
).fetchone()
conn.close()
if row is None:
return None
file_bytes = bytes(row["file_data"])
mime = (row["mime_type"] or "").lower()
if mime.startswith("video/"):
frame_index = det.get("frame_index", 0) or 0
# Pick a matching temp-file suffix so OpenCV selects the right codec
_mime_to_ext = {"video/mp4": ".mp4", "video/webm": ".webm",
"video/avi": ".avi", "video/quicktime": ".mov"}
vsuffix = _mime_to_ext.get(mime, ".mp4")
img = self._extract_frame_at_sync(file_bytes, frame_index, suffix=vsuffix)
else:
arr = np.frombuffer(file_bytes, dtype=np.uint8)
try:
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
except Exception:
return None
if img is None:
return None
x1, y1, x2, y2 = bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]
pad = 20
h, w = img.shape[:2]
cx1 = max(0, x1 - pad)
cy1 = max(0, y1 - pad)
cx2 = min(w, x2 + pad)
cy2 = min(h, y2 + pad)
crop = img[cy1:cy2, cx1:cx2]
_, buf = cv2.imencode(".jpg", crop, [cv2.IMWRITE_JPEG_QUALITY, 85])
return buf.tobytes()
# Module-level singleton
_face_service: FaceService | None = None
def get_face_service() -> FaceService:
global _face_service
if _face_service is None:
_face_service = FaceService()
return _face_service
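
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the service API. Assumes an image
    # path on the command line (falls back to a hypothetical "sample.jpg").
    import sys

    async def _demo() -> None:
        service = get_face_service()
        if not service.available:
            print("insightface not available; nothing to do")
            return
        path = sys.argv[1] if len(sys.argv) > 1 else "sample.jpg"
        with open(path, "rb") as fh:
            data = fh.read()
        faces = await service.detect(data)
        for f in faces:
            print(f"face {f.face_index}: bbox={f.bbox}")

    asyncio.run(_demo())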