"""
face_service.py
===============

FaceService: wrapper around insightface for face detection and recognition.

Runs CPU-bound work in a ThreadPoolExecutor(max_workers=1).
Falls back gracefully if insightface is not installed (available=False).
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import os
|
||
import tempfile
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from dataclasses import dataclass
|
||
from typing import Optional
|
||
|
||
import numpy as np
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
try:
|
||
from insightface.app import FaceAnalysis as _FaceAnalysis
|
||
_INSIGHTFACE_AVAILABLE = True
|
||
except ImportError:
|
||
_FaceAnalysis = None # type: ignore
|
||
_INSIGHTFACE_AVAILABLE = False
|
||
|
||
_SIMILARITY_THRESHOLD = 0.4
|
||
|
||
|
||
@dataclass
class DetectedFace:
    """One face found by FaceService._detect_sync, before any DB persistence."""

    face_index: int  # position of this face in the image's detection list
    bbox: dict  # {x1, y1, x2, y2} pixel coordinates in the source image
    embedding: np.ndarray  # normalized embedding (float32) from insightface
    crop_bytes: bytes  # JPEG bytes of the face crop
|
||
|
||
|
||
@dataclass
class ScanResult:
    """Outcome of persisting one detected face during a scan pipeline."""

    detection_id: int  # row id returned by face_db.insert_detection
    face_index: int  # index of the face within its source image/frame
    bbox: dict  # {x1, y1, x2, y2} pixel coordinates
    matched_person_id: Optional[int]  # known person id, or None if unmatched
    matched_person_name: Optional[str]  # display name for matched_person_id
|
||
|
||
|
||
class FaceService:
|
||
available: bool
|
||
|
||
def __init__(self) -> None:
|
||
self.available = _INSIGHTFACE_AVAILABLE
|
||
self._executor = ThreadPoolExecutor(max_workers=1)
|
||
self._app = None
|
||
if self.available:
|
||
try:
|
||
self._app = _FaceAnalysis(providers=["CPUExecutionProvider"])
|
||
self._app.prepare(ctx_id=0, det_size=(640, 640))
|
||
logger.info("FaceService: insightface ready")
|
||
except Exception as exc:
|
||
logger.warning("FaceService: failed to init insightface: %s", exc)
|
||
self.available = False
|
||
|
||
# ------------------------------------------------------------------
|
||
# Low-level detection
|
||
# ------------------------------------------------------------------
|
||
|
||
def _detect_sync(self, image_bytes: bytes) -> list[DetectedFace]:
|
||
"""CPU-bound: detect faces in image bytes."""
|
||
import cv2
|
||
arr = np.frombuffer(image_bytes, dtype=np.uint8)
|
||
try:
|
||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||
except Exception as exc:
|
||
logger.warning("FaceService: cv2.imdecode failed: %s", exc)
|
||
return []
|
||
if img is None:
|
||
return []
|
||
try:
|
||
faces = self._app.get(img)
|
||
except Exception as exc:
|
||
logger.warning("FaceService: face detection failed: %s", exc)
|
||
return []
|
||
|
||
results = []
|
||
for i, face in enumerate(faces):
|
||
x1, y1, x2, y2 = (int(v) for v in face.bbox)
|
||
bbox = {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
|
||
emb = face.normed_embedding.astype(np.float32)
|
||
|
||
# Crop with padding
|
||
pad = 20
|
||
h, w = img.shape[:2]
|
||
cx1 = max(0, x1 - pad)
|
||
cy1 = max(0, y1 - pad)
|
||
cx2 = min(w, x2 + pad)
|
||
cy2 = min(h, y2 + pad)
|
||
crop = img[cy1:cy2, cx1:cx2]
|
||
_, buf = cv2.imencode(".jpg", crop, [cv2.IMWRITE_JPEG_QUALITY, 85])
|
||
crop_bytes = buf.tobytes()
|
||
|
||
results.append(DetectedFace(
|
||
face_index=i,
|
||
bbox=bbox,
|
||
embedding=emb,
|
||
crop_bytes=crop_bytes,
|
||
))
|
||
return results
|
||
|
||
async def detect(self, image_bytes: bytes) -> list[DetectedFace]:
|
||
"""Async face detection wrapper."""
|
||
loop = asyncio.get_event_loop()
|
||
return await loop.run_in_executor(self._executor, self._detect_sync, image_bytes)
|
||
|
||
# ------------------------------------------------------------------
|
||
# Matching
|
||
# ------------------------------------------------------------------
|
||
|
||
def find_best_match(
|
||
self,
|
||
embedding: np.ndarray,
|
||
known_list: list[dict],
|
||
) -> tuple[Optional[int], float]:
|
||
"""Return (person_id, similarity) of the best cosine-similarity match, or (None, 0.0)."""
|
||
if not known_list:
|
||
return None, 0.0
|
||
best_sim = 0.0
|
||
best_id = None
|
||
for entry in known_list:
|
||
sim = float(np.dot(embedding, entry["embedding"]))
|
||
if sim > best_sim:
|
||
best_sim = sim
|
||
best_id = entry["person_id"]
|
||
if best_sim >= _SIMILARITY_THRESHOLD:
|
||
return best_id, best_sim
|
||
return None, best_sim
|
||
|
||
# ------------------------------------------------------------------
|
||
# Clustering
|
||
# ------------------------------------------------------------------
|
||
|
||
def _cluster_sync(self, embeddings: list[dict], threshold: float) -> list[list[int]]:
|
||
"""
|
||
Union-find clustering of face embeddings by cosine similarity.
|
||
|
||
O(n²) memory — suitable for up to ~10k faces (1000 faces ≈ 4 MB float32).
|
||
Returns list of detection-id lists, one per cluster with ≥ 2 members.
|
||
"""
|
||
n = len(embeddings)
|
||
parent = list(range(n))
|
||
|
||
def find(x: int) -> int:
|
||
while parent[x] != x:
|
||
parent[x] = parent[parent[x]]
|
||
x = parent[x]
|
||
return x
|
||
|
||
def union(x: int, y: int) -> None:
|
||
px, py = find(x), find(y)
|
||
if px != py:
|
||
parent[py] = px
|
||
|
||
M = np.stack([e["embedding"] for e in embeddings])
|
||
norms = np.linalg.norm(M, axis=1, keepdims=True)
|
||
M_norm = M / (norms + 1e-8)
|
||
sim_matrix = M_norm @ M_norm.T
|
||
|
||
pairs = np.argwhere(sim_matrix >= threshold)
|
||
for i, j in pairs:
|
||
if i < j:
|
||
union(int(i), int(j))
|
||
|
||
groups: dict[int, list[int]] = {}
|
||
for idx, e in enumerate(embeddings):
|
||
root = find(idx)
|
||
groups.setdefault(root, []).append(e["id"])
|
||
return [ids for ids in groups.values() if len(ids) >= 2]
|
||
|
||
async def cluster_unidentified_faces(self, threshold: float = 0.45) -> list[list[int]]:
|
||
"""
|
||
Cluster all unidentified detections by embedding similarity and persist groups to face_db.
|
||
Clears existing groups before recomputing.
|
||
"""
|
||
if not self.available:
|
||
return []
|
||
import face_db
|
||
embeddings = face_db.get_unidentified_embeddings()
|
||
if len(embeddings) < 2:
|
||
face_db.clear_all_groups()
|
||
return []
|
||
loop = asyncio.get_event_loop()
|
||
groups = await loop.run_in_executor(
|
||
self._executor, self._cluster_sync, embeddings, threshold
|
||
)
|
||
face_db.clear_all_groups()
|
||
for det_ids in groups:
|
||
gid = face_db.create_group(threshold)
|
||
for det_id in det_ids:
|
||
face_db.assign_detection_to_group(det_id, gid)
|
||
return groups
|
||
|
||
def _assign_to_nearest_group_sync(self, embedding: np.ndarray) -> int | None:
|
||
"""
|
||
Compare embedding against existing group centroids and return the best matching group_id,
|
||
or None if no group exceeds its threshold.
|
||
Fast enough to call synchronously (< 50 groups × < 50 members).
|
||
"""
|
||
import face_db
|
||
groups = face_db.get_all_group_embeddings_with_threshold()
|
||
if not groups:
|
||
return None
|
||
norm_emb = embedding / (np.linalg.norm(embedding) + 1e-8)
|
||
best_gid: int | None = None
|
||
best_sim = -1.0
|
||
for g in groups:
|
||
M = np.stack(g["embeddings"])
|
||
norms = np.linalg.norm(M, axis=1, keepdims=True)
|
||
M_norm = M / (norms + 1e-8)
|
||
mean_sim = float(np.mean(M_norm @ norm_emb))
|
||
if mean_sim >= g["threshold"] and mean_sim > best_sim:
|
||
best_sim = mean_sim
|
||
best_gid = g["group_id"]
|
||
return best_gid
|
||
|
||
# ------------------------------------------------------------------
|
||
# High-level pipelines
|
||
# ------------------------------------------------------------------
|
||
|
||
    async def scan_input_image(self, source_id: int, image_bytes: bytes) -> list[ScanResult]:
        """Detect faces in an input image, auto-link if known, store to face_db.

        For each detected face: match against all known person embeddings,
        persist the detection (with its embedding) as source_type="input",
        and — when unmatched — try to attach it to an existing unidentified
        group. Returns one ScanResult per detected face.
        """
        if not self.available:
            return []
        import face_db
        faces = await self.detect(image_bytes)
        if not faces:
            return []
        known = face_db.get_all_embeddings()
        # person_id -> display name, so matches can be labeled without extra queries.
        persons_cache: dict[int, str] = {p["id"]: p["name"] for p in face_db.list_persons()}
        results = []
        for face in faces:
            person_id, _ = self.find_best_match(face.embedding, known)
            person_name = persons_cache.get(person_id) if person_id is not None else None
            det_id = face_db.insert_detection(
                source_type="input",
                source_id=source_id,
                embedding=face.embedding,
                bbox=face.bbox,
                frame_index=0,  # single image: frame index is always 0
                face_index=face.face_index,
                person_id=person_id,
            )
            if person_id is None:
                # Unknown face: fold it into the nearest unidentified group, if any.
                gid = self._assign_to_nearest_group_sync(face.embedding)
                if gid is not None:
                    face_db.assign_detection_to_group(det_id, gid)
            results.append(ScanResult(
                detection_id=det_id,
                face_index=face.face_index,
                bbox=face.bbox,
                matched_person_id=person_id,
                matched_person_name=person_name,
            ))
        return results
|
||
|
||
    async def scan_output_image(self, source_id: int, image_bytes: bytes) -> list[ScanResult]:
        """Detect faces in a generated output image. Silent background scan.

        Unlike scan_input_image, the embedding is NOT persisted
        (embedding=None) and only faces matched to a known person produce
        ScanResults; unmatched faces are still stored and grouped.
        """
        if not self.available:
            return []
        import face_db
        faces = await self.detect(image_bytes)
        if not faces:
            return []
        known = face_db.get_all_embeddings()
        # person_id -> display name, so matches can be labeled without extra queries.
        persons_cache: dict[int, str] = {p["id"]: p["name"] for p in face_db.list_persons()}
        results = []
        for face in faces:
            person_id, _ = self.find_best_match(face.embedding, known)
            det_id = face_db.insert_detection(
                source_type="output",
                source_id=source_id,
                embedding=None,  # discard; saves space; rescan fills on demand
                bbox=face.bbox,
                frame_index=0,  # single image: frame index is always 0
                face_index=face.face_index,
                person_id=person_id,
            )
            if person_id is None:
                # Unknown face: fold it into the nearest unidentified group, if any.
                gid = self._assign_to_nearest_group_sync(face.embedding)
                if gid is not None:
                    face_db.assign_detection_to_group(det_id, gid)
            if person_id is not None:
                person_name = persons_cache.get(person_id)
                results.append(ScanResult(
                    detection_id=det_id,
                    face_index=face.face_index,
                    bbox=face.bbox,
                    matched_person_id=person_id,
                    matched_person_name=person_name,
                ))
        logger.info(
            "Face scan [output image source_id=%d]: %d face(s) detected, %d matched",
            source_id, len(faces), sum(1 for r in results if r.matched_person_id is not None),
        )
        return results
|
||
|
||
def _extract_keyframes_sync(self, video_bytes: bytes, max_frames: int = 20) -> list:
|
||
"""Extract evenly-spaced keyframes from video bytes. Returns list of BGR numpy arrays."""
|
||
import cv2
|
||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
|
||
f.write(video_bytes)
|
||
tmp_path = f.name
|
||
try:
|
||
cap = cv2.VideoCapture(tmp_path)
|
||
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||
if total <= 0:
|
||
cap.release()
|
||
return []
|
||
n = min(max_frames, total)
|
||
indices = [int(i * total / n) for i in range(n)]
|
||
frames = []
|
||
for idx in indices:
|
||
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
|
||
ret, frame = cap.read()
|
||
if ret:
|
||
frames.append(frame)
|
||
cap.release()
|
||
return frames
|
||
finally:
|
||
try:
|
||
os.unlink(tmp_path)
|
||
except Exception:
|
||
pass
|
||
|
||
async def scan_video(self, source_id: int, video_bytes: bytes, max_frames: int = 20) -> list[ScanResult]:
|
||
"""Detect faces across video keyframes. Silent background scan."""
|
||
if not self.available:
|
||
return []
|
||
import cv2
|
||
import face_db
|
||
loop = asyncio.get_event_loop()
|
||
frames = await loop.run_in_executor(
|
||
self._executor, self._extract_keyframes_sync, video_bytes, max_frames
|
||
)
|
||
if not frames:
|
||
return []
|
||
known = face_db.get_all_embeddings()
|
||
persons_cache: dict[int, str] = {p["id"]: p["name"] for p in face_db.list_persons()}
|
||
results = []
|
||
seen_det_ids: set[int] = set()
|
||
for frame_idx, frame in enumerate(frames):
|
||
_, buf = cv2.imencode(".jpg", frame)
|
||
frame_bytes = buf.tobytes()
|
||
faces = await self.detect(frame_bytes)
|
||
for face in faces:
|
||
person_id, _ = self.find_best_match(face.embedding, known)
|
||
det_id = face_db.insert_detection(
|
||
source_type="output",
|
||
source_id=source_id,
|
||
embedding=None, # discard; saves space; rescan fills on demand
|
||
bbox=face.bbox,
|
||
frame_index=frame_idx,
|
||
face_index=face.face_index,
|
||
person_id=person_id,
|
||
)
|
||
if det_id not in seen_det_ids:
|
||
seen_det_ids.add(det_id)
|
||
if person_id is not None:
|
||
person_name = persons_cache.get(person_id)
|
||
results.append(ScanResult(
|
||
detection_id=det_id,
|
||
face_index=face.face_index,
|
||
bbox=face.bbox,
|
||
matched_person_id=person_id,
|
||
matched_person_name=person_name,
|
||
))
|
||
logger.info(
|
||
"Face scan [output video source_id=%d]: %d frame(s), %d result(s) matched",
|
||
source_id, len(frames), len(results),
|
||
)
|
||
return results
|
||
|
||
async def rescan_output_embedding(self, source_id: int) -> int:
|
||
"""
|
||
Re-detect faces in a stored output image and update NULL embeddings
|
||
for existing detections by bbox proximity matching.
|
||
Returns count of detections updated.
|
||
"""
|
||
if not self.available:
|
||
return 0
|
||
import sqlite3
|
||
import face_db
|
||
import generation_db
|
||
conn = sqlite3.connect(str(generation_db._DB_PATH), check_same_thread=False)
|
||
conn.row_factory = sqlite3.Row
|
||
row = conn.execute(
|
||
"SELECT file_data, mime_type FROM generation_files WHERE id = ?", (source_id,)
|
||
).fetchone()
|
||
conn.close()
|
||
if row is None:
|
||
return 0
|
||
file_bytes = bytes(row["file_data"])
|
||
mime = (row["mime_type"] or "").lower()
|
||
if mime.startswith("video/"):
|
||
return 0 # skip videos — too expensive for backfill
|
||
|
||
faces = await self.detect(file_bytes)
|
||
if not faces:
|
||
return 0
|
||
|
||
existing = [
|
||
d for d in face_db.get_detections_for_source("output", source_id)
|
||
if d.get("embedding") is None and d.get("bbox_json") not in (None, "{}")
|
||
]
|
||
if not existing:
|
||
return 0
|
||
|
||
updated = 0
|
||
for face in faces:
|
||
fx = (face.bbox["x1"] + face.bbox["x2"]) / 2
|
||
fy = (face.bbox["y1"] + face.bbox["y2"]) / 2
|
||
best_det = None
|
||
best_dist = float("inf")
|
||
for det in existing:
|
||
b = json.loads(det["bbox_json"])
|
||
dx = fx - (b["x1"] + b["x2"]) / 2
|
||
dy = fy - (b["y1"] + b["y2"]) / 2
|
||
dist = (dx * dx + dy * dy) ** 0.5
|
||
if dist < best_dist:
|
||
best_dist = dist
|
||
best_det = det
|
||
if best_det is not None and best_dist <= 50:
|
||
face_db.update_detection_embedding(best_det["id"], face.embedding)
|
||
existing = [d for d in existing if d["id"] != best_det["id"]]
|
||
updated += 1
|
||
if best_det.get("person_id") is None:
|
||
known = face_db.get_all_embeddings()
|
||
matched_pid, _ = self.find_best_match(face.embedding, known)
|
||
if matched_pid is not None:
|
||
face_db.link_detection_to_person(best_det["id"], matched_pid)
|
||
|
||
return updated
|
||
|
||
# ------------------------------------------------------------------
|
||
# Utility
|
||
# ------------------------------------------------------------------
|
||
|
||
def _extract_frame_at_sync(
|
||
self, video_bytes: bytes, frame_index: int, max_frames: int = 20,
|
||
suffix: str = ".mp4",
|
||
) -> "np.ndarray | None":
|
||
"""
|
||
Re-extract the specific video frame that was used during scan_video.
|
||
|
||
frame_index is the enumeration index (0…n-1) used by scan_video, NOT the raw
|
||
video frame number. We reconstruct the same sampling formula:
|
||
actual_frame = int(frame_index * total / n) where n = min(max_frames, total)
|
||
"""
|
||
import cv2
|
||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
|
||
f.write(video_bytes)
|
||
tmp_path = f.name
|
||
try:
|
||
cap = cv2.VideoCapture(tmp_path)
|
||
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||
if total <= 0:
|
||
cap.release()
|
||
return None
|
||
n = min(max_frames, total)
|
||
if frame_index >= n:
|
||
cap.release()
|
||
return None
|
||
actual_idx = int(frame_index * total / n)
|
||
cap.set(cv2.CAP_PROP_POS_FRAMES, actual_idx)
|
||
ret, frame = cap.read()
|
||
cap.release()
|
||
return frame if ret else None
|
||
except Exception:
|
||
return None
|
||
finally:
|
||
try:
|
||
os.unlink(tmp_path)
|
||
except Exception:
|
||
pass
|
||
|
||
def get_face_crop(self, detection_id: int) -> bytes | None:
|
||
"""Re-derive the face crop from the stored source image or video frame. Returns JPEG bytes or None."""
|
||
import cv2
|
||
import face_db
|
||
det = face_db.get_detection(detection_id)
|
||
if det is None:
|
||
return None
|
||
source_type = det["source_type"]
|
||
source_id = det["source_id"]
|
||
bbox_raw = det["bbox_json"]
|
||
if not bbox_raw:
|
||
return None
|
||
bbox = json.loads(bbox_raw)
|
||
|
||
img = None
|
||
if source_type == "input":
|
||
from input_image_db import get_image_data
|
||
image_bytes = get_image_data(source_id)
|
||
if image_bytes is None:
|
||
return None
|
||
arr = np.frombuffer(image_bytes, dtype=np.uint8)
|
||
try:
|
||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||
except Exception:
|
||
return None
|
||
elif source_type == "output":
|
||
import sqlite3
|
||
import generation_db
|
||
conn = sqlite3.connect(str(generation_db._DB_PATH), check_same_thread=False)
|
||
conn.row_factory = sqlite3.Row
|
||
row = conn.execute(
|
||
"SELECT file_data, mime_type FROM generation_files WHERE id = ?", (source_id,)
|
||
).fetchone()
|
||
conn.close()
|
||
if row is None:
|
||
return None
|
||
file_bytes = bytes(row["file_data"])
|
||
mime = (row["mime_type"] or "").lower()
|
||
if mime.startswith("video/"):
|
||
frame_index = det.get("frame_index", 0) or 0
|
||
# Pick a matching temp-file suffix so OpenCV selects the right codec
|
||
_mime_to_ext = {"video/mp4": ".mp4", "video/webm": ".webm",
|
||
"video/avi": ".avi", "video/quicktime": ".mov"}
|
||
vsuffix = _mime_to_ext.get(mime, ".mp4")
|
||
img = self._extract_frame_at_sync(file_bytes, frame_index, suffix=vsuffix)
|
||
else:
|
||
arr = np.frombuffer(file_bytes, dtype=np.uint8)
|
||
try:
|
||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||
except Exception:
|
||
return None
|
||
|
||
if img is None:
|
||
return None
|
||
|
||
x1, y1, x2, y2 = bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]
|
||
pad = 20
|
||
h, w = img.shape[:2]
|
||
cx1 = max(0, x1 - pad)
|
||
cy1 = max(0, y1 - pad)
|
||
cx2 = min(w, x2 + pad)
|
||
cy2 = min(h, y2 + pad)
|
||
crop = img[cy1:cy2, cx1:cx2]
|
||
_, buf = cv2.imencode(".jpg", crop, [cv2.IMWRITE_JPEG_QUALITY, 85])
|
||
return buf.tobytes()
|
||
|
||
|
||
# Module-level singleton
|
||
_face_service: FaceService | None = None
|
||
|
||
|
||
def get_face_service() -> FaceService:
    """Return the module-level FaceService singleton, creating it lazily.

    The first call pays the insightface initialization cost; later calls
    return the cached instance.
    """
    global _face_service
    if _face_service is None:
        _face_service = FaceService()
    return _face_service
|