""" input_image_db.py ================= SQLite helpers for tracking Discord-channel-backed input images. The database stores one row per attachment, so a single message with multiple images produces multiple rows. The stable lookup key is the auto-increment `id`; `(original_message_id, filename)` is a unique composite constraint. """ from __future__ import annotations import sqlite3 from pathlib import Path DB_PATH = Path(__file__).parent / "input_images.db" _SCHEMA = """ CREATE TABLE IF NOT EXISTS input_images ( id INTEGER PRIMARY KEY AUTOINCREMENT, original_message_id INTEGER NOT NULL, bot_reply_id INTEGER NOT NULL, channel_id INTEGER NOT NULL, filename TEXT NOT NULL, is_active INTEGER NOT NULL DEFAULT 0, image_data BLOB, active_slot_key TEXT DEFAULT NULL, UNIQUE(original_message_id, filename) ) """ # Live migrations applied on every startup — safe if column already exists _MIGRATIONS = [ "ALTER TABLE input_images ADD COLUMN image_data BLOB", "ALTER TABLE input_images ADD COLUMN active_slot_key TEXT DEFAULT NULL", ] # Columns returned by get_image / get_all_images (excludes the potentially large BLOB) _SAFE_COLS = "id, original_message_id, bot_reply_id, channel_id, filename, is_active, active_slot_key" def _connect() -> sqlite3.Connection: conn = sqlite3.connect(str(DB_PATH)) conn.row_factory = sqlite3.Row return conn def init_db() -> None: """Create the input_images table if it does not exist, and run column migrations.""" with _connect() as conn: conn.execute(_SCHEMA) for stmt in _MIGRATIONS: try: conn.execute(stmt) except sqlite3.OperationalError: pass # column already exists conn.commit() def upsert_image( original_message_id: int, bot_reply_id: int, channel_id: int, filename: str, image_data: bytes | None = None, ) -> int: """ Insert a new image record or update an existing one. Returns the stable row ``id`` (used as the persistent view key). When *image_data* is provided it is stored as a BLOB; on UPDATE it is only overwritten when not None. """ with _connect() as conn: existing = conn.execute( "SELECT id FROM input_images WHERE original_message_id = ? AND filename = ?", (original_message_id, filename), ).fetchone() if existing: if image_data is not None: conn.execute( "UPDATE input_images SET bot_reply_id = ?, channel_id = ?, image_data = ? WHERE id = ?", (bot_reply_id, channel_id, image_data, existing["id"]), ) else: conn.execute( "UPDATE input_images SET bot_reply_id = ?, channel_id = ? WHERE id = ?", (bot_reply_id, channel_id, existing["id"]), ) row_id = existing["id"] else: cur = conn.execute( """ INSERT INTO input_images (original_message_id, bot_reply_id, channel_id, filename, is_active, image_data) VALUES (?, ?, ?, ?, 0, ?) """, (original_message_id, bot_reply_id, channel_id, filename, image_data), ) row_id = cur.lastrowid conn.commit() return row_id def get_image_data(row_id: int) -> bytes | None: """Return the raw image bytes for a row, or None if the row is missing or has no data.""" with _connect() as conn: row = conn.execute( "SELECT image_data FROM input_images WHERE id = ?", (row_id,), ).fetchone() if row is None: return None return row["image_data"] def activate_image_for_slot(row_id: int, slot_key: str, comfy_input_path: str) -> str: """ Write the stored image bytes to ``{comfy_input_path}/ttb_{slot_key}{ext}`` and record the slot assignment in the DB. Returns the basename of the written file (e.g. ``ttb_input_image.jpg``). Raises ``ValueError`` if the row has no image_data (user must re-upload). """ data = get_image_data(row_id) if data is None: raise ValueError( f"No image data stored for row {row_id}. Re-upload the image to backfill." ) row = get_image(row_id) if row is None: raise ValueError(f"No DB record for row id {row_id}") ext = Path(row["filename"]).suffix # e.g. ".jpg" dest_name = f"ttb_{slot_key}{ext}" input_path = Path(comfy_input_path) input_path.mkdir(parents=True, exist_ok=True) # Remove any existing file for this slot (may have a different extension) for old in input_path.glob(f"ttb_{slot_key}.*"): try: old.unlink() except Exception: pass (input_path / dest_name).write_bytes(data) # Update DB: clear slot from previous holder, then assign to this row with _connect() as conn: conn.execute( "UPDATE input_images SET active_slot_key = NULL WHERE active_slot_key = ?", (slot_key,), ) conn.execute( "UPDATE input_images SET active_slot_key = ? WHERE id = ?", (slot_key, row_id), ) conn.commit() return dest_name def deactivate_image_slot(slot_key: str, comfy_input_path: str) -> None: """ Remove the ``ttb_{slot_key}.*`` file from the ComfyUI input folder and clear the matching DB column. Safe no-op if nothing is active for that slot. """ input_path = Path(comfy_input_path) for old in input_path.glob(f"ttb_{slot_key}.*"): try: old.unlink() except Exception: pass with _connect() as conn: conn.execute( "UPDATE input_images SET active_slot_key = NULL WHERE active_slot_key = ?", (slot_key,), ) conn.commit() def set_active(row_id: int) -> None: """Mark one image as active and clear the active flag on all others.""" with _connect() as conn: conn.execute("UPDATE input_images SET is_active = 0") conn.execute( "UPDATE input_images SET is_active = 1 WHERE id = ?", (row_id,), ) conn.commit() def get_image(row_id: int) -> dict | None: """Return a single image row by its auto-increment id (excluding image_data), or None.""" with _connect() as conn: row = conn.execute( f"SELECT {_SAFE_COLS} FROM input_images WHERE id = ?", (row_id,), ).fetchone() return dict(row) if row else None def get_all_images() -> list[dict]: """Return all image rows as a list of dicts (excluding image_data).""" with _connect() as conn: rows = conn.execute( f"SELECT {_SAFE_COLS} FROM input_images" ).fetchall() return [dict(r) for r in rows] def delete_image(row_id: int, comfy_input_path: str | None = None) -> None: """ Remove an image record from the database. If the image is currently active for a slot and *comfy_input_path* is provided, the corresponding ``ttb_{slot_key}.*`` file is also deleted. """ row = get_image(row_id) if row and row.get("active_slot_key") and comfy_input_path: deactivate_image_slot(row["active_slot_key"], comfy_input_path) with _connect() as conn: conn.execute( "DELETE FROM input_images WHERE id = ?", (row_id,), ) conn.commit()