"""
commands/input_images.py
========================

Channel-backed input image management.

Images uploaded to the designated `comfy-input` channel get a persistent
"✅ Set as input" button posted by the bot — one reply per attachment so
every image in a multi-image message is independently selectable. Once its
images are re-posted, the user's original upload is deleted, and each image
is scanned for faces in the background so the uploader can name them.

The `sync-inputs` command backfills buttons for images uploaded while the
bot was offline.

Persistent views survive bot restarts: on_ready re-registers every view
stored in the SQLite database.
"""

from __future__ import annotations

import asyncio
import io
import logging
from pathlib import Path

import discord
from discord.ext import commands

from image_utils import compress_to_discord_limit
from input_image_db import (
    activate_image_for_slot,
    get_all_images,
    upsert_image,
)

logger = logging.getLogger(__name__)


async def _identify_faces_discord(
    bot,
    message: discord.Message,
    row_id: int,
    image_bytes: bytes,
) -> None:
    """
    Scan a newly registered input image for faces and ask the uploader to
    name any that were not auto-matched.

    Flow:
      1. Return early if the face service is unavailable or every detected
         face already matched a known person.
      2. Post a prompt (with the face crops attached) in the upload channel
         and wait up to 120 s for the uploader to reply to it with names,
         comma-separated, in face order. Missing trailing names count as
         blank and blank names are skipped.
      3. If any supplied name already exists, post a warning and wait up to
         120 s for a reply giving, per conflicting face and in the listed
         order, either ``same`` (keep that existing name) or a replacement.
      4. Link each non-blank name through ``face_db.get_or_create_person`` /
         ``face_db.link_detection_to_person`` and post a summary.

    If a round links nothing (e.g. the reply held only blanks), the flow waits
    for another reply to the most recent bot message, for at most 3 rounds in
    total. Only replies from the original uploader that reply to that bot
    message are accepted; a timeout at any wait ends the flow silently.

    This coroutine is started with ``asyncio.create_task`` and nobody awaits
    it, so ordinary exceptions are logged (with traceback) and swallowed.
    """
    # Discord rejects a single message carrying more than this many files.
    max_files_per_message = 10

    try:
        from face_service import get_face_service
        import face_db

        svc = get_face_service()
        if not svc.available:
            return

        face_db.init_db()
        results = await svc.scan_input_image(row_id, image_bytes)
        # Detections still needing a name.
        pending = [r for r in results if r.matched_person_id is None]
        if not pending:
            return

        crops = []
        for det in pending:
            crop = svc.get_face_crop(det.detection_id)
            if crop:
                crops.append(
                    discord.File(io.BytesIO(crop), filename=f"face_{det.face_index}.jpg")
                )

        prompt_text = (
            f"\U0001f50d Found {len(pending)} new face(s) in your image. "
            f"Reply with names in order (comma-separated): `Name1, Name2, ...`\n"
            f"_(or ignore to skip identification)_"
        )
        # The prompt carries the first batch of crops; any overflow goes into
        # follow-up messages so a crowded image doesn't abort the whole flow.
        listen_to = await message.channel.send(
            prompt_text, files=crops[:max_files_per_message]
        )
        for start in range(max_files_per_message, len(crops), max_files_per_message):
            await message.channel.send(files=crops[start:start + max_files_per_message])

        def _replies_to(target: discord.Message):
            """Build a wait_for predicate: a reply to *target* by the uploader."""
            target_id = target.id

            def _check(m: discord.Message) -> bool:
                return (
                    m.reference is not None
                    and m.reference.message_id == target_id
                    and m.author.id == message.author.id
                )

            return _check

        for _round in range(3):
            try:
                reply = await bot.wait_for(
                    "message", check=_replies_to(listen_to), timeout=120
                )
            except asyncio.TimeoutError:
                return

            names = [part.strip() for part in reply.content.split(",")]
            # Pad so every pending face has an entry; blanks are skipped below.
            names += [""] * (len(pending) - len(names))

            # A name that already belongs to a person is ambiguous: the user
            # may mean that person or someone new, so ask. Entries are
            # (index into pending, name).
            conflicts: list[tuple[int, str]] = [
                (idx, name)
                for idx, name in enumerate(names[: len(pending)])
                if name and face_db.person_name_exists(name)
            ]

            if conflicts:
                conflict_lines = "\n".join(
                    f"Face {pending[idx].face_index + 1} → `{name}`"
                    for idx, name in conflicts
                )
                warn_msg = await reply.reply(
                    f"\u26a0\ufe0f These names already exist:\n{conflict_lines}\n\n"
                    f"Reply `same` for any that should link to the **existing** person, "
                    f"or provide a different name — one value per conflicting face "
                    f"(comma-separated, in the same order as listed above)."
                )
                # Any further round listens for replies to the warning instead.
                listen_to = warn_msg

                try:
                    conflict_reply = await bot.wait_for(
                        "message", check=_replies_to(warn_msg), timeout=120
                    )
                except asyncio.TimeoutError:
                    return

                resolved = [v.strip() for v in conflict_reply.content.split(",")]
                # Conflicts without a matching answer keep the name as typed.
                for (pending_idx, old_name), answer in zip(conflicts, resolved):
                    names[pending_idx] = old_name if answer.lower() == "same" else answer

            confirmed: list[str] = []
            for det, name in zip(pending, names):
                if not name:
                    continue
                use_existing = face_db.person_name_exists(name)
                person_id, _ = face_db.get_or_create_person(name)
                face_db.link_detection_to_person(det.detection_id, person_id)
                status = "linked to existing" if use_existing else "new"
                confirmed.append(f"{name} ({status})")

            if confirmed:
                await reply.reply(f"\u2705 Identified: {', '.join(confirmed)}")
                return

    except Exception as exc:
        # Background task: keep the traceback, otherwise failures are opaque.
        logger.warning("Face identification flow failed: %s", exc, exc_info=True)


# Lower-case filename suffixes (compared against Path(...).suffix.lower())
# that are treated as images in the input channel. Frozen so the shared
# allow-list cannot be mutated at runtime.
IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp"})


class PersistentSetInputView(discord.ui.View):
    """
    Single-button view that promotes one stored image to the ``input_image``
    slot.

    The view never times out (``timeout=None``), and its button has a fixed
    ``custom_id`` of the form ``set_input:<row_id>``. Together these let the
    view be re-attached to its message with ``bot.add_view`` and still route
    clicks to the right database row. One instance exists per stored
    attachment row.
    """

    def __init__(self, bot, config, row_id: int):
        super().__init__(timeout=None)
        self._bot = bot
        self._config = config
        self._row_id = row_id

        # Encoding the row id in custom_id makes every button unique and
        # lets the callback know which image to activate.
        set_button = discord.ui.Button(
            custom_id=f"set_input:{row_id}",
            label="✅ Set as input",
            style=discord.ButtonStyle.success,
        )
        set_button.callback = self._set_callback
        self.add_item(set_button)

    async def _set_callback(self, interaction: discord.Interaction) -> None:
        """Activate this view's image as ``input_image`` and report back ephemerally."""
        # Defer right away so the work below isn't bound by Discord's short
        # initial-response deadline; results go out as ephemeral follow-ups.
        await interaction.response.defer(ephemeral=True)
        slot_name = "input_image"
        try:
            target_dir = self._config.comfy_input_path
            activated = activate_image_for_slot(self._row_id, slot_name, target_dir)
            self._bot.comfy.state_manager.set_override(slot_name, activated)
            await interaction.followup.send(
                f"✅ Input image set to `{activated}`", ephemeral=True
            )
        except Exception as err:
            logger.exception("set_input button failed for row %s", self._row_id)
            await interaction.followup.send(f"❌ Error: {err}", ephemeral=True)


# Strong references to in-flight background tasks. The event loop keeps only
# weak references to tasks, so a task from asyncio.create_task() that is not
# stored anywhere can be garbage-collected before it finishes.
_background_tasks = set()


async def _register_attachment(bot, config, message: discord.Message, attachment: discord.Attachment) -> None:
    """
    Post a reply with the image preview, a Set-as-input button, and record it in the DB.

    Steps:
      1. Read the attachment.
      2. Post a copy (compressed for Discord if needed) to the same channel.
      3. Store the original, uncompressed bytes with ``upsert_image``.
      4. Attach a ``PersistentSetInputView`` to the bot's reply.
      5. Start the face scan / identification flow as a background task.

    Args:
        bot: The bot; used to register the persistent view and, later, to
            wait for identification replies.
        config: Bot config handed to ``PersistentSetInputView``.
        message: The user's message that carried the attachment.
        attachment: The image attachment to register.
    """
    logger.info("[_register_attachment] Start")
    logger.info("[_register_attachment] Reading attachment")
    original_data = await attachment.read()
    original_filename = attachment.filename

    # Compress only for the Discord re-send (8 MiB bot limit)
    send_data, send_filename = compress_to_discord_limit(original_data, original_filename)

    file = discord.File(io.BytesIO(send_data), filename=send_filename)
    reply = await message.channel.send(f"`{original_filename}`", file=file)

    # Store original quality bytes in DB
    row_id = upsert_image(message.id, reply.id, message.channel.id, original_filename, image_data=original_data)
    view = PersistentSetInputView(bot, config, row_id)
    bot.add_view(view, message_id=reply.id)
    await reply.edit(view=view)
    logger.info("[_register_attachment] Done")

    # Background face scan + optional identification prompt. Keep a strong
    # reference until the task completes so it cannot be collected mid-run.
    task = asyncio.create_task(
        _identify_faces_discord(bot, message, row_id, original_data)
    )
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


def setup_input_image_commands(bot, config=None):
    """
    Register input image commands and the on_message listener.

    Registers:
      * an ``on_message`` listener that re-posts every image uploaded to the
        configured input channel with a Set-as-input button, then deletes the
        original upload once all of its images are registered;
      * the ``sync-inputs`` (alias ``si``) command, which backfills buttons
        for untracked images already in that channel.

    Args:
        bot: The commands.Bot to register on.
        config: Bot config providing ``comfy_input_channel_id``. When ``None``,
            the listener ignores every message and the command reports that
            the config is unavailable.
    """
    if config is None:
        # Warn once here instead of on every message the bot receives.
        logger.warning(
            "[setup_input_image_commands] Config is none; input image handling disabled"
        )

    def _is_image(filename: str) -> bool:
        """Return True when *filename*'s extension (case-insensitive) is an image type."""
        return Path(filename).suffix.lower() in IMAGE_EXTENSIONS

    async def _try_register(message: discord.Message, attachment: discord.Attachment) -> bool:
        """Register one attachment; log and return False instead of raising on failure."""
        try:
            await _register_attachment(bot, config, message, attachment)
        except Exception:
            logger.exception(
                "Failed to register attachment %r from message %s",
                attachment.filename,
                message.id,
            )
            return False
        return True

    @bot.listen("on_message")
    async def _on_input_channel_message(message: discord.Message) -> None:
        """Watch the comfy-input channel and attach a Set-as-input button to every image upload."""
        if config is None:
            return
        if message.channel.id != config.comfy_input_channel_id:
            return
        if message.author.bot:
            return

        image_attachments = [a for a in message.attachments if _is_image(a.filename)]
        if not image_attachments:
            logger.info("[_on_input_channel_message] No image attachments")
            return

        # Register every image even if one fails, so a single bad attachment
        # doesn't block the others.
        all_registered = True
        for attachment in image_attachments:
            if not await _try_register(message, attachment):
                all_registered = False

        if not all_registered:
            # Keep the original upload so no image is lost; `sync-inputs`
            # can retry any attachment that did not get a DB row.
            return

        try:
            await message.delete()
        except discord.Forbidden:
            logger.warning("Missing manage_messages permission to delete message %s", message.id)
        except Exception as exc:
            logger.warning("Could not delete message %s: %s", message.id, exc)

    @bot.command(
        name="sync-inputs",
        aliases=["si"],
        extras={"category": "Files"},
        help="Scan the comfy-input channel and add 'Set as input' buttons to any untracked images.",
    )
    async def sync_inputs_command(ctx: commands.Context) -> None:
        """Backfill Set-as-input buttons for images uploaded while the bot was offline."""
        if config is None:
            await ctx.reply("Bot config is not available.", mention_author=False)
            return

        channel = bot.get_channel(config.comfy_input_channel_id)
        if channel is None:
            try:
                channel = await bot.fetch_channel(config.comfy_input_channel_id)
            except Exception as exc:
                await ctx.reply(f"❌ Could not access input channel: {exc}", mention_author=False)
                return

        # Track existing records as (message_id, filename) pairs
        existing = {(row["original_message_id"], row["filename"]) for row in get_all_images()}

        new_count = 0
        failed_count = 0
        async for message in channel.history(limit=None):
            # Skips the bot's own re-posts, including ones made during this scan.
            if message.author.bot:
                continue

            had_new = False
            had_failure = False
            for attachment in message.attachments:
                if not _is_image(attachment.filename):
                    continue
                if (message.id, attachment.filename) in existing:
                    continue

                if await _try_register(message, attachment):
                    existing.add((message.id, attachment.filename))
                    new_count += 1
                    had_new = True
                else:
                    failed_count += 1
                    had_failure = True

            # Only remove the upload once every new image in it is tracked.
            if had_new and not had_failure:
                try:
                    await message.delete()
                except Exception as exc:
                    logger.warning("sync-inputs: could not delete message %s: %s", message.id, exc)

        already = len(get_all_images()) - new_count
        summary = f"Synced {new_count} new image(s). {already} already known."
        if failed_count:
            summary += f" {failed_count} failed (see logs)."
        await ctx.reply(summary, mention_author=False)