""" generation_db.py ================ SQLite persistence for ComfyUI generation history and output file blobs. Two tables ---------- generation_history : one row per prompt submitted to ComfyUI generation_files : one row per output file (image / video) as a BLOB The module-level ``_DB_PATH`` is set by :func:`init_db`; all other functions use that path so callers never need to pass it around. """ from __future__ import annotations import json import secrets import sqlite3 from datetime import datetime, timezone from pathlib import Path from typing import Any _DB_PATH: Path = Path(__file__).parent / "generation_history.db" _SCHEMA = """ CREATE TABLE IF NOT EXISTS generation_history ( id INTEGER PRIMARY KEY AUTOINCREMENT, prompt_id TEXT UNIQUE NOT NULL, source TEXT NOT NULL, user_label TEXT, overrides TEXT, seed INTEGER, created_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS generation_files ( id INTEGER PRIMARY KEY AUTOINCREMENT, generation_id INTEGER NOT NULL REFERENCES generation_history(id), filename TEXT NOT NULL, file_data BLOB NOT NULL, mime_type TEXT ); CREATE TABLE IF NOT EXISTS generation_shares ( id INTEGER PRIMARY KEY AUTOINCREMENT, share_token TEXT UNIQUE NOT NULL, prompt_id TEXT NOT NULL, owner_label TEXT NOT NULL, created_at TEXT NOT NULL ); """ def _connect(db_path: Path | None = None) -> sqlite3.Connection: path = db_path if db_path is not None else _DB_PATH conn = sqlite3.connect(str(path), check_same_thread=False) conn.row_factory = sqlite3.Row return conn def _detect_mime(data: bytes) -> str: """Detect MIME type from magic bytes.""" if data[:8] == b"\x89PNG\r\n\x1a\n": return "image/png" if data[:2] == b"\xff\xd8": return "image/jpeg" if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"WEBP": return "image/webp" if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"AVI ": return "video/x-msvideo" if len(data) >= 8 and data[4:8] == b"ftyp": return "video/mp4" if data[:4] == b"\x1aE\xdf\xa3": # EBML (WebM/MKV) return "video/webm" return "application/octet-stream" def init_db(db_path: Path = _DB_PATH) -> None: """Create tables if they don't exist. Accepts a path for testability.""" global _DB_PATH _DB_PATH = db_path with _connect(db_path) as conn: conn.executescript(_SCHEMA) conn.commit() def record_generation( prompt_id: str, source: str, user_label: str | None, overrides_dict: dict[str, Any] | None, seed: int | None, ) -> int: """Insert a generation history row. Returns the auto-increment ``id``.""" overrides_json = json.dumps(overrides_dict) if overrides_dict is not None else None created_at = datetime.now(timezone.utc).isoformat() with _connect() as conn: cur = conn.execute( """ INSERT INTO generation_history (prompt_id, source, user_label, overrides, seed, created_at) VALUES (?, ?, ?, ?, ?, ?) """, (prompt_id, source, user_label, overrides_json, seed, created_at), ) conn.commit() return cur.lastrowid # type: ignore[return-value] def record_file(generation_id: int, filename: str, file_data: bytes) -> None: """Insert a file BLOB row, auto-detecting MIME type from magic bytes.""" mime_type = _detect_mime(file_data) with _connect() as conn: conn.execute( """ INSERT INTO generation_files (generation_id, filename, file_data, mime_type) VALUES (?, ?, ?, ?) """, (generation_id, filename, file_data, mime_type), ) conn.commit() def _rows_to_history(conn: sqlite3.Connection, rows) -> list[dict]: """Convert raw generation_history rows (with optional share_token) to dicts.""" result: list[dict] = [] for row in rows: d = dict(row) if d["overrides"]: try: d["overrides"] = json.loads(d["overrides"]) except (json.JSONDecodeError, TypeError): d["overrides"] = {} else: d["overrides"] = {} files = conn.execute( "SELECT filename FROM generation_files WHERE generation_id = ?", (d["id"],), ).fetchall() d["file_paths"] = [f["filename"] for f in files] result.append(d) return result def get_history(limit: int = 50) -> list[dict]: """Return recent generation rows (newest first) with a ``file_paths`` list.""" with _connect() as conn: rows = conn.execute( """ SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at, s.share_token FROM generation_history h LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label ORDER BY h.id DESC LIMIT ? """, (limit,), ).fetchall() return _rows_to_history(conn, rows) def get_history_for_user(user_label: str, limit: int = 50) -> list[dict]: """Return recent generation rows for a specific user (newest first).""" with _connect() as conn: rows = conn.execute( """ SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at, s.share_token FROM generation_history h LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ? WHERE h.user_label = ? ORDER BY h.id DESC LIMIT ? """, (user_label, user_label, limit), ).fetchall() return _rows_to_history(conn, rows) def get_generation(prompt_id: str) -> dict | None: """Return the generation_history row for *prompt_id*, or None.""" with _connect() as conn: row = conn.execute( "SELECT id, prompt_id, user_label FROM generation_history WHERE prompt_id = ?", (prompt_id,), ).fetchone() return dict(row) if row else None def get_generation_full(prompt_id: str) -> dict | None: """Return overrides (parsed) + seed for *prompt_id*, or None if not found.""" with _connect() as conn: row = conn.execute( "SELECT prompt_id, user_label, overrides, seed FROM generation_history WHERE prompt_id = ?", (prompt_id,), ).fetchone() if row is None: return None d = dict(row) if d["overrides"]: try: d["overrides"] = json.loads(d["overrides"]) except (json.JSONDecodeError, TypeError): d["overrides"] = {} else: d["overrides"] = {} return d def search_history_for_user(user_label: str, query: str, limit: int = 50) -> list[dict]: """Return history rows where the overrides JSON contains *query* (case-insensitive).""" with _connect() as conn: rows = conn.execute( """ SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at, s.share_token FROM generation_history h LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ? WHERE h.user_label = ? AND LOWER(h.overrides) LIKE LOWER(?) ORDER BY h.id DESC LIMIT ? """, (user_label, user_label, f"%{query}%", limit), ).fetchall() return _rows_to_history(conn, rows) def search_history(query: str, limit: int = 50) -> list[dict]: """Admin version: search all users' history for *query* in overrides JSON.""" with _connect() as conn: rows = conn.execute( """ SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at, s.share_token FROM generation_history h LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label WHERE LOWER(h.overrides) LIKE LOWER(?) ORDER BY h.id DESC LIMIT ? """, (f"%{query}%", limit), ).fetchall() return _rows_to_history(conn, rows) def create_share(prompt_id: str, owner_label: str) -> str: """Create a share token for *prompt_id*. Idempotent — returns the same token if one exists.""" token = secrets.token_urlsafe(32) created_at = datetime.now(timezone.utc).isoformat() with _connect() as conn: conn.execute( """ INSERT OR IGNORE INTO generation_shares (share_token, prompt_id, owner_label, created_at) VALUES (?, ?, ?, ?) """, (token, prompt_id, owner_label, created_at), ) conn.commit() row = conn.execute( "SELECT share_token FROM generation_shares WHERE prompt_id = ? AND owner_label = ?", (prompt_id, owner_label), ).fetchone() return row["share_token"] def revoke_share(prompt_id: str, owner_label: str) -> bool: """Delete the share token for *prompt_id*. Returns True if a row was deleted.""" with _connect() as conn: cur = conn.execute( "DELETE FROM generation_shares WHERE prompt_id = ? AND owner_label = ?", (prompt_id, owner_label), ) conn.commit() return cur.rowcount > 0 def get_share_by_token(token: str) -> dict | None: """Return generation info for a share token, or None if not found/revoked.""" with _connect() as conn: row = conn.execute( """ SELECT h.prompt_id, h.overrides, h.seed, h.created_at FROM generation_shares s JOIN generation_history h ON h.prompt_id = s.prompt_id WHERE s.share_token = ? """, (token,), ).fetchone() if row is None: return None d = dict(row) if d["overrides"]: try: d["overrides"] = json.loads(d["overrides"]) except (json.JSONDecodeError, TypeError): d["overrides"] = {} else: d["overrides"] = {} return d def get_files(prompt_id: str) -> list[dict]: """Return all output files for *prompt_id* as ``[{filename, data, mime_type}]``.""" with _connect() as conn: gen_row = conn.execute( "SELECT id FROM generation_history WHERE prompt_id = ?", (prompt_id,), ).fetchone() if not gen_row: return [] files = conn.execute( "SELECT filename, file_data, mime_type FROM generation_files WHERE generation_id = ?", (gen_row["id"],), ).fetchall() return [ { "filename": f["filename"], "data": bytes(f["file_data"]), "mime_type": f["mime_type"], } for f in files ]