""" commands/input_images.py ======================== Channel-backed input image management. Images uploaded to the designated `comfy-input` channel get a persistent "✅ Set as input" button posted by the bot — one reply per attachment so every image in a multi-image message is independently selectable. Persistent views survive bot restarts: on_ready re-registers every view stored in the SQLite database. """ from __future__ import annotations import asyncio import io import logging from pathlib import Path import discord from discord.ext import commands from image_utils import compress_to_discord_limit from input_image_db import ( activate_image_for_slot, get_all_images, upsert_image, ) logger = logging.getLogger(__name__) async def _identify_faces_discord( bot, message: discord.Message, row_id: int, image_bytes: bytes, ) -> None: """ After an input image is registered, scan it for faces and prompt the uploader to identify any that weren't auto-matched. Only the original uploader's replies are accepted (author check on wait_for). The loop runs at most 3 rounds to resolve deduplication conflicts. """ try: from face_service import get_face_service import face_db svc = get_face_service() if not svc.available: return face_db.init_db() results = await svc.scan_input_image(row_id, image_bytes) unknown = [r for r in results if r.matched_person_id is None] if not unknown: return # Build initial prompt with face crops as attachments files = [] for r in unknown: crop = svc.get_face_crop(r.detection_id) if crop: files.append(discord.File(io.BytesIO(crop), filename=f"face_{r.face_index}.jpg")) n = len(unknown) prompt_text = ( f"\U0001f50d Found {n} new face(s) in your image. " f"Reply with names in order (comma-separated): `Name1, Name2, ...`\n" f"_(or ignore to skip identification)_" ) bot_msg = await message.channel.send(prompt_text, files=files) def _check(m: discord.Message) -> bool: return ( m.reference is not None and m.reference.message_id == bot_msg.id and m.author.id == message.author.id ) pending = list(unknown) # detections still needing names for _round in range(3): try: reply = await bot.wait_for("message", check=_check, timeout=120) except asyncio.TimeoutError: return raw_names = [n.strip() for n in reply.content.split(",")] if len(raw_names) < len(pending): raw_names += [""] * (len(pending) - len(raw_names)) # Check for conflicts (name exists but user said a new name) conflicts: list[tuple[int, str]] = [] # (index in pending, name) for idx, (det, name) in enumerate(zip(pending, raw_names)): if name and face_db.person_name_exists(name): conflicts.append((idx, name)) if conflicts: conflict_lines = "\n".join( f"Face {pending[idx].face_index + 1} → `{name}`" for idx, name in conflicts ) warn_msg = await reply.reply( f"\u26a0\ufe0f These names already exist:\n{conflict_lines}\n\n" f"Reply `same` for any that should link to the **existing** person, " f"or provide a different name — one value per conflicting face " f"(comma-separated, in the same order as listed above)." ) bot_msg = warn_msg # Update check to listen for reply to the new warning message def _check_conflict(m: discord.Message) -> bool: return ( m.reference is not None and m.reference.message_id == warn_msg.id and m.author.id == message.author.id ) try: conflict_reply = await bot.wait_for( "message", check=_check_conflict, timeout=120 ) except asyncio.TimeoutError: return resolved = [v.strip() for v in conflict_reply.content.split(",")] # Apply resolved names back to the original name list for list_pos, (pending_idx, old_name) in enumerate(conflicts): if list_pos < len(resolved): val = resolved[list_pos] raw_names[pending_idx] = old_name if val.lower() == "same" else val # Apply names — skip blanks confirmed: list[str] = [] for det, name in zip(pending, raw_names): if not name: continue use_existing = face_db.person_name_exists(name) person_id, _ = face_db.get_or_create_person(name) face_db.link_detection_to_person(det.detection_id, person_id) status = "linked to existing" if use_existing else "new" confirmed.append(f"{name} ({status})") if confirmed: await reply.reply(f"\u2705 Identified: {', '.join(confirmed)}") return except Exception as exc: logger.warning("Face identification flow failed: %s", exc) IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp"} class PersistentSetInputView(discord.ui.View): """ A persistent view that survives bot restarts. One instance is created per DB row (i.e. per attachment). The button's custom_id encodes the row id so the callback can look up the exact filename to download. """ def __init__(self, bot, config, row_id: int): super().__init__(timeout=None) self._bot = bot self._config = config self._row_id = row_id btn = discord.ui.Button( label="✅ Set as input", style=discord.ButtonStyle.success, custom_id=f"set_input:{row_id}", ) btn.callback = self._set_callback self.add_item(btn) async def _set_callback(self, interaction: discord.Interaction) -> None: await interaction.response.defer(ephemeral=True) try: filename = activate_image_for_slot( self._row_id, "input_image", self._config.comfy_input_path ) self._bot.comfy.state_manager.set_override("input_image", filename) await interaction.followup.send( f"✅ Input image set to `{filename}`", ephemeral=True ) except Exception as exc: logger.exception("set_input button failed for row %s", self._row_id) await interaction.followup.send(f"❌ Error: {exc}", ephemeral=True) async def _register_attachment(bot, config, message: discord.Message, attachment: discord.Attachment) -> None: """Post a reply with the image preview, a Set-as-input button, and record it in the DB.""" logger.info("[_register_attachment] Start") original_data = await attachment.read() original_filename = attachment.filename logger.info("[_register_attachment] Reading attachment") # Compress only for the Discord re-send (8 MiB bot limit) send_data, send_filename = compress_to_discord_limit(original_data, original_filename) file = discord.File(io.BytesIO(send_data), filename=send_filename) reply = await message.channel.send(f"`{original_filename}`", file=file) # Store original quality bytes in DB row_id = upsert_image(message.id, reply.id, message.channel.id, original_filename, image_data=original_data) view = PersistentSetInputView(bot, config, row_id) bot.add_view(view, message_id=reply.id) logger.info("[_register_attachment] Done") await reply.edit(view=view) # Background face scan + optional identification prompt asyncio.create_task( _identify_faces_discord(bot, message, row_id, original_data) ) def setup_input_image_commands(bot, config=None): """Register input image commands and the on_message listener.""" @bot.listen("on_message") async def _on_input_channel_message(message: discord.Message) -> None: """Watch the comfy-input channel and attach a Set-as-input button to every image upload.""" if config is None: logger.warning("[_on_input_channel_message] Config is none") return if message.channel.id != config.comfy_input_channel_id: return if message.author.bot: return image_attachments = [ a for a in message.attachments if Path(a.filename).suffix.lower() in IMAGE_EXTENSIONS ] if not image_attachments: logger.info("[_on_input_channel_message] No image attachments") return for attachment in image_attachments: await _register_attachment(bot, config, message, attachment) try: await message.delete() except discord.Forbidden: logger.warning("Missing manage_messages permission to delete message %s", message.id) except Exception as exc: logger.warning("Could not delete message %s: %s", message.id, exc) @bot.command( name="sync-inputs", aliases=["si"], extras={"category": "Files"}, help="Scan the comfy-input channel and add 'Set as input' buttons to any untracked images.", ) async def sync_inputs_command(ctx: commands.Context) -> None: """Backfill Set-as-input buttons for images uploaded while the bot was offline.""" if config is None: await ctx.reply("Bot config is not available.", mention_author=False) return channel = bot.get_channel(config.comfy_input_channel_id) if channel is None: try: channel = await bot.fetch_channel(config.comfy_input_channel_id) except Exception as exc: await ctx.reply(f"❌ Could not access input channel: {exc}", mention_author=False) return # Track existing records as (message_id, filename) pairs existing = {(row["original_message_id"], row["filename"]) for row in get_all_images()} new_count = 0 async for message in channel.history(limit=None): if message.author.bot: continue had_new = False for attachment in message.attachments: if Path(attachment.filename).suffix.lower() not in IMAGE_EXTENSIONS: continue if (message.id, attachment.filename) in existing: continue await _register_attachment(bot, config, message, attachment) existing.add((message.id, attachment.filename)) new_count += 1 had_new = True if had_new: try: await message.delete() except Exception as exc: logger.warning("sync-inputs: could not delete message %s: %s", message.id, exc) already = len(get_all_images()) - new_count await ctx.reply( f"Synced {new_count} new image(s). {already} already known.", mention_author=False, )