""" generation_db.py ================ SQLite persistence for ComfyUI generation history and output file blobs. Two tables ---------- generation_history : one row per prompt submitted to ComfyUI generation_files : one row per output file (image / video) as a BLOB The module-level ``_DB_PATH`` is set by :func:`init_db`; all other functions use that path so callers never need to pass it around. """ from __future__ import annotations import json import secrets import sqlite3 from datetime import datetime, timezone from pathlib import Path from typing import Any _DB_PATH: Path = Path(__file__).parent / "generation_history.db" _SCHEMA = """ CREATE TABLE IF NOT EXISTS generation_history ( id INTEGER PRIMARY KEY AUTOINCREMENT, prompt_id TEXT UNIQUE NOT NULL, source TEXT NOT NULL, user_label TEXT, overrides TEXT, seed INTEGER, created_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS generation_files ( id INTEGER PRIMARY KEY AUTOINCREMENT, generation_id INTEGER NOT NULL REFERENCES generation_history(id), filename TEXT NOT NULL, file_data BLOB NOT NULL, mime_type TEXT ); CREATE TABLE IF NOT EXISTS generation_shares ( id INTEGER PRIMARY KEY AUTOINCREMENT, share_token TEXT UNIQUE NOT NULL, prompt_id TEXT NOT NULL, owner_label TEXT NOT NULL, created_at TEXT NOT NULL, is_public INTEGER NOT NULL DEFAULT 0, expires_at TEXT, max_views INTEGER, view_count INTEGER NOT NULL DEFAULT 0 ); """ def _migrate_shares_table(conn: sqlite3.Connection) -> None: migrations = [ "ALTER TABLE generation_shares ADD COLUMN is_public INTEGER NOT NULL DEFAULT 0", "ALTER TABLE generation_shares ADD COLUMN expires_at TEXT", "ALTER TABLE generation_shares ADD COLUMN max_views INTEGER", "ALTER TABLE generation_shares ADD COLUMN view_count INTEGER NOT NULL DEFAULT 0", ] for sql in migrations: try: conn.execute(sql) except sqlite3.OperationalError: pass # column already exists conn.commit() def _connect(db_path: Path | None = None) -> sqlite3.Connection: path = db_path if db_path is not None else _DB_PATH conn = sqlite3.connect(str(path), check_same_thread=False) conn.row_factory = sqlite3.Row return conn def _detect_mime(data: bytes) -> str: """Detect MIME type from magic bytes.""" if data[:8] == b"\x89PNG\r\n\x1a\n": return "image/png" if data[:2] == b"\xff\xd8": return "image/jpeg" if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"WEBP": return "image/webp" if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"AVI ": return "video/x-msvideo" if len(data) >= 8 and data[4:8] == b"ftyp": return "video/mp4" if data[:4] == b"\x1aE\xdf\xa3": # EBML (WebM/MKV) return "video/webm" return "application/octet-stream" def init_db(db_path: Path = _DB_PATH) -> None: """Create tables if they don't exist. Accepts a path for testability.""" global _DB_PATH _DB_PATH = db_path with _connect(db_path) as conn: conn.executescript(_SCHEMA) conn.commit() _migrate_shares_table(conn) def record_generation( prompt_id: str, source: str, user_label: str | None, overrides_dict: dict[str, Any] | None, seed: int | None, ) -> int: """Insert a generation history row. Returns the auto-increment ``id``.""" overrides_json = json.dumps(overrides_dict) if overrides_dict is not None else None created_at = datetime.now(timezone.utc).isoformat() with _connect() as conn: cur = conn.execute( """ INSERT INTO generation_history (prompt_id, source, user_label, overrides, seed, created_at) VALUES (?, ?, ?, ?, ?, ?) """, (prompt_id, source, user_label, overrides_json, seed, created_at), ) conn.commit() return cur.lastrowid # type: ignore[return-value] def record_file(generation_id: int, filename: str, file_data: bytes) -> int: """Insert a file BLOB row, auto-detecting MIME type from magic bytes. Returns the row id.""" mime_type = _detect_mime(file_data) with _connect() as conn: cur = conn.execute( """ INSERT INTO generation_files (generation_id, filename, file_data, mime_type) VALUES (?, ?, ?, ?) """, (generation_id, filename, file_data, mime_type), ) conn.commit() return cur.lastrowid # type: ignore[return-value] def _rows_to_history(conn: sqlite3.Connection, rows) -> list[dict]: """Convert raw generation_history rows (with optional share_token) to dicts.""" result: list[dict] = [] for row in rows: d = dict(row) if d["overrides"]: try: d["overrides"] = json.loads(d["overrides"]) except (json.JSONDecodeError, TypeError): d["overrides"] = {} else: d["overrides"] = {} files = conn.execute( "SELECT filename FROM generation_files WHERE generation_id = ?", (d["id"],), ).fetchall() d["file_paths"] = [f["filename"] for f in files] result.append(d) return result def get_history(limit: int = 50) -> list[dict]: """Return recent generation rows (newest first) with a ``file_paths`` list.""" with _connect() as conn: rows = conn.execute( """ SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at, s.share_token, s.is_public AS share_is_public, s.expires_at AS share_expires_at, s.max_views AS share_max_views, s.view_count AS share_view_count FROM generation_history h LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label ORDER BY h.id DESC LIMIT ? """, (limit,), ).fetchall() return _rows_to_history(conn, rows) def get_history_for_user(user_label: str, limit: int = 50) -> list[dict]: """Return recent generation rows for a specific user (newest first).""" with _connect() as conn: rows = conn.execute( """ SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at, s.share_token, s.is_public AS share_is_public, s.expires_at AS share_expires_at, s.max_views AS share_max_views, s.view_count AS share_view_count FROM generation_history h LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ? WHERE h.user_label = ? ORDER BY h.id DESC LIMIT ? """, (user_label, user_label, limit), ).fetchall() return _rows_to_history(conn, rows) def get_generation(prompt_id: str) -> dict | None: """Return the generation_history row for *prompt_id*, or None.""" with _connect() as conn: row = conn.execute( "SELECT id, prompt_id, user_label FROM generation_history WHERE prompt_id = ?", (prompt_id,), ).fetchone() return dict(row) if row else None def get_generation_full(prompt_id: str) -> dict | None: """Return overrides (parsed) + seed for *prompt_id*, or None if not found.""" with _connect() as conn: row = conn.execute( "SELECT prompt_id, user_label, overrides, seed FROM generation_history WHERE prompt_id = ?", (prompt_id,), ).fetchone() if row is None: return None d = dict(row) if d["overrides"]: try: d["overrides"] = json.loads(d["overrides"]) except (json.JSONDecodeError, TypeError): d["overrides"] = {} else: d["overrides"] = {} return d def search_history_for_user(user_label: str, query: str, limit: int = 50) -> list[dict]: """Return history rows where the overrides JSON contains *query* (case-insensitive).""" with _connect() as conn: rows = conn.execute( """ SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at, s.share_token, s.is_public AS share_is_public, s.expires_at AS share_expires_at, s.max_views AS share_max_views, s.view_count AS share_view_count FROM generation_history h LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ? WHERE h.user_label = ? AND LOWER(h.overrides) LIKE LOWER(?) ORDER BY h.id DESC LIMIT ? """, (user_label, user_label, f"%{query}%", limit), ).fetchall() return _rows_to_history(conn, rows) def search_history(query: str, limit: int = 50) -> list[dict]: """Admin version: search all users' history for *query* in overrides JSON.""" with _connect() as conn: rows = conn.execute( """ SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at, s.share_token, s.is_public AS share_is_public, s.expires_at AS share_expires_at, s.max_views AS share_max_views, s.view_count AS share_view_count FROM generation_history h LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label WHERE LOWER(h.overrides) LIKE LOWER(?) ORDER BY h.id DESC LIMIT ? """, (f"%{query}%", limit), ).fetchall() return _rows_to_history(conn, rows) def _is_share_expired(share_row: dict) -> bool: """Return True if the share has passed its time or view limits.""" if share_row["expires_at"] and datetime.fromisoformat(share_row["expires_at"]) <= datetime.now(timezone.utc): return True if share_row["max_views"] is not None and share_row["view_count"] >= share_row["max_views"]: return True return False def _is_share_streaming_expired(share_row: dict) -> bool: """Expiry check for file-streaming calls (no view_count increment). Uses strict-greater-than so that files remain accessible within the same page view that just consumed the last allowed view.""" if share_row["expires_at"] and datetime.fromisoformat(share_row["expires_at"]) <= datetime.now(timezone.utc): return True if share_row["max_views"] is not None and share_row["view_count"] > share_row["max_views"]: return True return False def get_active_share_for_prompt(prompt_id: str, owner_label: str) -> dict | None: """Return the active (non-expired) share row for *prompt_id*+*owner_label*, or None. If a row exists but is already expired, it is auto-deleted and None is returned so the caller can create a new share immediately. """ with _connect() as conn: row = conn.execute( """ SELECT share_token, is_public, expires_at, max_views, view_count FROM generation_shares WHERE prompt_id = ? AND owner_label = ? """, (prompt_id, owner_label), ).fetchone() if row is None: return None d = dict(row) if _is_share_expired(d): conn.execute( "DELETE FROM generation_shares WHERE share_token = ?", (d["share_token"],), ) conn.commit() return None return d def create_share( prompt_id: str, owner_label: str, *, is_public: bool = False, expires_at: str | None = None, max_views: int | None = None, ) -> dict: """Insert a fresh share row. Returns dict with share_token, is_public, expires_at, max_views.""" token = secrets.token_urlsafe(32) created_at = datetime.now(timezone.utc).isoformat() with _connect() as conn: conn.execute( """ INSERT INTO generation_shares (share_token, prompt_id, owner_label, created_at, is_public, expires_at, max_views, view_count) VALUES (?, ?, ?, ?, ?, ?, ?, 0) """, (token, prompt_id, owner_label, created_at, int(is_public), expires_at, max_views), ) conn.commit() row = conn.execute( "SELECT share_token, is_public, expires_at, max_views FROM generation_shares WHERE share_token = ?", (token,), ).fetchone() return dict(row) def revoke_share(prompt_id: str, owner_label: str | None = None) -> bool: """Delete the share token for *prompt_id*. If *owner_label* is provided, only delete that user's share. If None (admin), delete any share for the prompt_id. Returns True if a row was deleted. """ with _connect() as conn: if owner_label is not None: cur = conn.execute( "DELETE FROM generation_shares WHERE prompt_id = ? AND owner_label = ?", (prompt_id, owner_label), ) else: cur = conn.execute( "DELETE FROM generation_shares WHERE prompt_id = ?", (prompt_id,), ) conn.commit() return cur.rowcount > 0 def get_share_by_token(token: str) -> dict | None: """Return generation info for a share token (incrementing view_count), or None if not found/expired.""" with _connect() as conn: row = conn.execute( """ SELECT s.share_token, s.is_public, s.expires_at, s.max_views, s.view_count, h.prompt_id, h.overrides, h.seed, h.created_at FROM generation_shares s JOIN generation_history h ON h.prompt_id = s.prompt_id WHERE s.share_token = ? """, (token,), ).fetchone() if row is None: return None d = dict(row) if _is_share_expired(d): conn.execute("DELETE FROM generation_shares WHERE share_token = ?", (token,)) conn.commit() return None # Increment view count conn.execute( "UPDATE generation_shares SET view_count = view_count + 1 WHERE share_token = ?", (token,), ) conn.commit() if d["overrides"]: try: d["overrides"] = json.loads(d["overrides"]) except (json.JSONDecodeError, TypeError): d["overrides"] = {} else: d["overrides"] = {} return d def get_share_meta(token: str) -> dict | None: """Return share metadata without incrementing view_count. Used by file-streaming endpoints.""" with _connect() as conn: row = conn.execute( """ SELECT s.share_token, s.is_public, s.expires_at, s.max_views, s.view_count, h.prompt_id, h.overrides, h.seed, h.created_at FROM generation_shares s JOIN generation_history h ON h.prompt_id = s.prompt_id WHERE s.share_token = ? """, (token,), ).fetchone() if row is None: return None d = dict(row) if _is_share_streaming_expired(d): conn.execute("DELETE FROM generation_shares WHERE share_token = ?", (token,)) conn.commit() return None if d["overrides"]: try: d["overrides"] = json.loads(d["overrides"]) except (json.JSONDecodeError, TypeError): d["overrides"] = {} else: d["overrides"] = {} return d def get_file_ids_for_prompt(prompt_id: str) -> list[int]: """Return generation_files.id values for all files belonging to prompt_id.""" with _connect() as conn: rows = conn.execute( """SELECT gf.id FROM generation_files gf JOIN generation_history gh ON gh.id = gf.generation_id WHERE gh.prompt_id = ?""", (prompt_id,), ).fetchall() return [r["id"] for r in rows] def get_generation_ids_for_file_ids(file_ids: list[int]) -> list[int]: """Return distinct generation_id values for the given generation_files row ids.""" if not file_ids: return [] placeholders = ",".join("?" * len(file_ids)) with _connect() as conn: rows = conn.execute( f"SELECT DISTINCT generation_id FROM generation_files WHERE id IN ({placeholders})", tuple(file_ids), ).fetchall() return [r["generation_id"] for r in rows] def get_file_ids_for_generation_ids(gen_ids: list[int]) -> dict[int, list[int]]: """Return {gen_id: [file_id, …]} for the given generation_history row ids.""" if not gen_ids: return {} placeholders = ",".join("?" * len(gen_ids)) with _connect() as conn: rows = conn.execute( f"SELECT generation_id, id FROM generation_files WHERE generation_id IN ({placeholders})", tuple(gen_ids), ).fetchall() result: dict[int, list[int]] = {gid: [] for gid in gen_ids} for row in rows: result[row["generation_id"]].append(row["id"]) return result def get_files(prompt_id: str) -> list[dict]: """Return all output files for *prompt_id* as ``[{filename, data, mime_type}]``.""" with _connect() as conn: gen_row = conn.execute( "SELECT id FROM generation_history WHERE prompt_id = ?", (prompt_id,), ).fetchone() if not gen_row: return [] files = conn.execute( "SELECT filename, file_data, mime_type FROM generation_files WHERE generation_id = ?", (gen_row["id"],), ).fetchall() return [ { "filename": f["filename"], "data": bytes(f["file_data"]), "mime_type": f["mime_type"], } for f in files ]