132 lines
4.7 KiB
Python
132 lines
4.7 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
sync_faces.py
|
|
=============
|
|
|
|
One-time backfill script: scan existing input_images and generation_files
|
|
for faces and store detections in faces.db.
|
|
|
|
Usage:
|
|
python sync_faces.py [--dry-run] [--input-only] [--output-only] [--cluster] [--cluster-threshold T]
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import logging
|
|
import sqlite3
|
|
|
|
# Plain "LEVEL: message" output at INFO and above for this CLI script.
logging.basicConfig(
    format="%(levelname)s: %(message)s",
    level=logging.INFO,
)
# Module logger used by the scan and clustering steps below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _fetch_rows(db_path, query: str) -> list[sqlite3.Row]:
    """Run *query* against the SQLite database at *db_path* and return all rows.

    *db_path* is converted with ``str()`` before connecting. The connection is
    closed in a ``finally`` so it is released even if the query raises, and
    the rows are fully materialised first so no cursor stays open while the
    face scans run afterwards.
    """
    conn = sqlite3.connect(str(db_path), check_same_thread=False)
    try:
        conn.row_factory = sqlite3.Row
        return conn.execute(query).fetchall()
    finally:
        # sqlite3.Connection's own context manager only commits/rolls back;
        # it does not close, hence the explicit finally.
        conn.close()


def _tally(results) -> tuple[int, int, int]:
    """Return ``(faces, matched, unidentified)`` counts for one scan's results.

    A face counts as matched when its ``matched_person_id`` is not ``None``;
    every other detected face counts as unidentified.
    """
    faces = matched = 0
    for r in results:
        faces += 1
        if r.matched_person_id is not None:
            matched += 1
    return faces, matched, faces - matched


async def _scan_input_images(svc, db_path, dry_run: bool) -> tuple[int, int, int]:
    """Scan every ``input_images`` row that has image data for faces.

    Returns the summed ``(faces, matched, unidentified)`` tallies, which stay
    zero in a dry run (rows are only logged). A failure on one image is
    logged as a warning and the scan continues with the next row.
    """
    logger.info("Scanning input images…")
    rows = _fetch_rows(
        db_path,
        "SELECT id, image_data FROM input_images WHERE image_data IS NOT NULL",
    )

    faces = matched = unidentified = 0
    for row in rows:
        row_id = row["id"]
        image_bytes = bytes(row["image_data"])
        logger.info(" input image id=%d (%d bytes)", row_id, len(image_bytes))
        if dry_run:
            continue
        try:
            results = await svc.scan_input_image(row_id, image_bytes)
            f, m, u = _tally(results)
        except Exception as exc:  # best-effort backfill: keep going on errors
            logger.warning(" Failed for input id=%d: %s", row_id, exc)
            continue
        faces += f
        matched += m
        unidentified += u
    return faces, matched, unidentified


async def _scan_output_files(svc, db_path, dry_run: bool) -> tuple[int, int, int]:
    """Scan every ``generation_files`` row that has file data for faces.

    ``image/*`` rows are passed to ``svc.scan_output_image`` and ``video/*``
    rows to ``svc.scan_video``; rows with any other (or missing) MIME type are
    logged but not scanned. Rows with NULL ``file_data`` are excluded by the
    query, because ``bytes(None)`` would raise outside the per-row ``try``.

    Returns the summed ``(faces, matched, unidentified)`` tallies, which stay
    zero in a dry run. A failure on one file is logged as a warning and the
    scan continues with the next row.
    """
    logger.info("Scanning generation output files…")
    rows = _fetch_rows(
        db_path,
        "SELECT id, file_data, mime_type FROM generation_files "
        "WHERE file_data IS NOT NULL",
    )

    faces = matched = unidentified = 0
    for row in rows:
        file_id = row["id"]
        file_data = bytes(row["file_data"])
        mime_type = row["mime_type"] or ""
        logger.info(
            " output file id=%d mime=%s (%d bytes)", file_id, mime_type, len(file_data)
        )
        if dry_run:
            continue
        try:
            if mime_type.startswith("image/"):
                results = await svc.scan_output_image(file_id, file_data)
            elif mime_type.startswith("video/"):
                results = await svc.scan_video(file_id, file_data)
            else:
                continue  # not an image or video; nothing to scan
            f, m, u = _tally(results)
        except Exception as exc:  # best-effort backfill: keep going on errors
            logger.warning(" Failed for output id=%d: %s", file_id, exc)
            continue
        faces += f
        matched += m
        unidentified += u
    return faces, matched, unidentified


async def main(
    dry_run: bool, input_only: bool, output_only: bool,
    cluster: bool, cluster_threshold: float,
) -> None:
    """Backfill face detections for media already stored in the databases.

    Args:
        dry_run: Only log the rows that would be scanned; face detection and
            clustering are skipped.
        input_only: Scan only the ``input_images`` table.
        output_only: Scan only the ``generation_files`` table. Setting this
            together with ``input_only`` scans nothing (a warning is logged).
        cluster: After a non-dry run, call the face service to cluster
            unidentified faces.
        cluster_threshold: Cosine similarity threshold forwarded to
            ``svc.cluster_unidentified_faces``.

    Logs an error and returns early if the face service reports that
    insightface is not available.
    """
    # Project modules are imported lazily so ``--help`` works without them.
    import face_db
    from face_service import get_face_service

    face_db.init_db()
    svc = get_face_service()

    if not svc.available:
        logger.error(
            "insightface is not available. "
            "Install: pip install insightface onnxruntime opencv-python"
        )
        return

    import generation_db
    import input_image_db

    if input_only and output_only:
        logger.warning(
            "input_only and output_only are both set; nothing will be scanned."
        )

    total_faces = total_matched = total_unidentified = 0

    if not output_only:
        faces, matched, unidentified = await _scan_input_images(
            svc, input_image_db.DB_PATH, dry_run
        )
        total_faces += faces
        total_matched += matched
        total_unidentified += unidentified

    if not input_only:
        # NOTE(review): relies on generation_db's private ``_DB_PATH``.
        faces, matched, unidentified = await _scan_output_files(
            svc, generation_db._DB_PATH, dry_run
        )
        total_faces += faces
        total_matched += matched
        total_unidentified += unidentified

    if dry_run:
        logger.info("Dry run — no data written.")
        return

    logger.info(
        "Done. %d faces detected, %d matched to known persons, %d unidentified",
        total_faces,
        total_matched,
        total_unidentified,
    )

    if cluster:
        logger.info("Clustering unidentified faces (threshold=%.2f)…", cluster_threshold)
        groups = await svc.cluster_unidentified_faces(cluster_threshold)
        logger.info("Clustering: %d groups created", len(groups))
|
|
|
|
|
|
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the face-detection backfill script.

    ``--input-only`` and ``--output-only`` form a mutually exclusive group,
    because combining them would select no media at all.
    """
    parser = argparse.ArgumentParser(
        description="Backfill face detections for existing media"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Log the media that would be scanned without running face detection",
    )
    scope = parser.add_mutually_exclusive_group()
    scope.add_argument(
        "--input-only",
        action="store_true",
        help="Scan only the input_images table",
    )
    scope.add_argument(
        "--output-only",
        action="store_true",
        help="Scan only the generation_files table",
    )
    parser.add_argument(
        "--cluster", action="store_true", help="Run auto-clustering after scanning"
    )
    parser.add_argument(
        "--cluster-threshold",
        type=float,
        default=0.45,
        metavar="T",
        help="Cosine similarity threshold for clustering (default: 0.45)",
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    asyncio.run(main(
        dry_run=args.dry_run,
        input_only=args.input_only,
        output_only=args.output_only,
        cluster=args.cluster,
        cluster_threshold=args.cluster_threshold,
    ))
|