manual submit
This commit is contained in:
@@ -14,6 +14,7 @@ stored in the SQLite database.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
@@ -30,6 +31,127 @@ from input_image_db import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _identify_faces_discord(
|
||||
bot,
|
||||
message: discord.Message,
|
||||
row_id: int,
|
||||
image_bytes: bytes,
|
||||
) -> None:
|
||||
"""
|
||||
After an input image is registered, scan it for faces and prompt the
|
||||
uploader to identify any that weren't auto-matched.
|
||||
|
||||
Only the original uploader's replies are accepted (author check on wait_for).
|
||||
The loop runs at most 3 rounds to resolve deduplication conflicts.
|
||||
"""
|
||||
try:
|
||||
from face_service import get_face_service
|
||||
import face_db
|
||||
|
||||
svc = get_face_service()
|
||||
if not svc.available:
|
||||
return
|
||||
|
||||
face_db.init_db()
|
||||
results = await svc.scan_input_image(row_id, image_bytes)
|
||||
unknown = [r for r in results if r.matched_person_id is None]
|
||||
if not unknown:
|
||||
return
|
||||
|
||||
# Build initial prompt with face crops as attachments
|
||||
files = []
|
||||
for r in unknown:
|
||||
crop = svc.get_face_crop(r.detection_id)
|
||||
if crop:
|
||||
files.append(discord.File(io.BytesIO(crop), filename=f"face_{r.face_index}.jpg"))
|
||||
|
||||
n = len(unknown)
|
||||
prompt_text = (
|
||||
f"\U0001f50d Found {n} new face(s) in your image. "
|
||||
f"Reply with names in order (comma-separated): `Name1, Name2, ...`\n"
|
||||
f"_(or ignore to skip identification)_"
|
||||
)
|
||||
bot_msg = await message.channel.send(prompt_text, files=files)
|
||||
|
||||
def _check(m: discord.Message) -> bool:
|
||||
return (
|
||||
m.reference is not None
|
||||
and m.reference.message_id == bot_msg.id
|
||||
and m.author.id == message.author.id
|
||||
)
|
||||
|
||||
pending = list(unknown) # detections still needing names
|
||||
|
||||
for _round in range(3):
|
||||
try:
|
||||
reply = await bot.wait_for("message", check=_check, timeout=120)
|
||||
except asyncio.TimeoutError:
|
||||
return
|
||||
|
||||
raw_names = [n.strip() for n in reply.content.split(",")]
|
||||
if len(raw_names) < len(pending):
|
||||
raw_names += [""] * (len(pending) - len(raw_names))
|
||||
|
||||
# Check for conflicts (name exists but user said a new name)
|
||||
conflicts: list[tuple[int, str]] = [] # (index in pending, name)
|
||||
for idx, (det, name) in enumerate(zip(pending, raw_names)):
|
||||
if name and face_db.person_name_exists(name):
|
||||
conflicts.append((idx, name))
|
||||
|
||||
if conflicts:
|
||||
conflict_lines = "\n".join(
|
||||
f"Face {pending[idx].face_index + 1} → `{name}`"
|
||||
for idx, name in conflicts
|
||||
)
|
||||
warn_msg = await reply.reply(
|
||||
f"\u26a0\ufe0f These names already exist:\n{conflict_lines}\n\n"
|
||||
f"Reply `same` for any that should link to the **existing** person, "
|
||||
f"or provide a different name — one value per conflicting face "
|
||||
f"(comma-separated, in the same order as listed above)."
|
||||
)
|
||||
bot_msg = warn_msg
|
||||
|
||||
# Update check to listen for reply to the new warning message
|
||||
def _check_conflict(m: discord.Message) -> bool:
|
||||
return (
|
||||
m.reference is not None
|
||||
and m.reference.message_id == warn_msg.id
|
||||
and m.author.id == message.author.id
|
||||
)
|
||||
|
||||
try:
|
||||
conflict_reply = await bot.wait_for(
|
||||
"message", check=_check_conflict, timeout=120
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return
|
||||
|
||||
resolved = [v.strip() for v in conflict_reply.content.split(",")]
|
||||
# Apply resolved names back to the original name list
|
||||
for list_pos, (pending_idx, old_name) in enumerate(conflicts):
|
||||
if list_pos < len(resolved):
|
||||
val = resolved[list_pos]
|
||||
raw_names[pending_idx] = old_name if val.lower() == "same" else val
|
||||
|
||||
# Apply names — skip blanks
|
||||
confirmed: list[str] = []
|
||||
for det, name in zip(pending, raw_names):
|
||||
if not name:
|
||||
continue
|
||||
use_existing = face_db.person_name_exists(name)
|
||||
person_id, _ = face_db.get_or_create_person(name)
|
||||
face_db.link_detection_to_person(det.detection_id, person_id)
|
||||
status = "linked to existing" if use_existing else "new"
|
||||
confirmed.append(f"{name} ({status})")
|
||||
|
||||
if confirmed:
|
||||
await reply.reply(f"\u2705 Identified: {', '.join(confirmed)}")
|
||||
return
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Face identification flow failed: %s", exc)
|
||||
|
||||
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp"}
|
||||
|
||||
|
||||
@@ -92,6 +214,11 @@ async def _register_attachment(bot, config, message: discord.Message, attachment
|
||||
logger.info("[_register_attachment] Done")
|
||||
await reply.edit(view=view)
|
||||
|
||||
# Background face scan + optional identification prompt
|
||||
asyncio.create_task(
|
||||
_identify_faces_discord(bot, message, row_id, original_data)
|
||||
)
|
||||
|
||||
|
||||
def setup_input_image_commands(bot, config=None):
|
||||
"""Register input image commands and the on_message listener."""
|
||||
|
||||
Reference in New Issue
Block a user