Files
comfy-discord-web/generation_db.py
Khoa (Revenovich) Tran Gia 6004b000a7 manual submit
2026-03-07 21:49:16 +07:00

501 lines
18 KiB
Python

"""
generation_db.py
================
SQLite persistence for ComfyUI generation history and output file blobs.
Two tables
----------
generation_history : one row per prompt submitted to ComfyUI
generation_files : one row per output file (image / video) as a BLOB
The module-level ``_DB_PATH`` is set by :func:`init_db`; all other
functions use that path so callers never need to pass it around.
"""
from __future__ import annotations

import json
import secrets
import sqlite3
from contextlib import closing
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
# Default database location, next to this module; init_db() may repoint it.
_DB_PATH: Path = Path(__file__).parent / "generation_history.db"
# Idempotent DDL run by init_db(); "IF NOT EXISTS" makes it safe on every
# startup.  The string is executed at runtime — do not edit casually.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS generation_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
prompt_id TEXT UNIQUE NOT NULL,
source TEXT NOT NULL,
user_label TEXT,
overrides TEXT,
seed INTEGER,
created_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS generation_files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
generation_id INTEGER NOT NULL REFERENCES generation_history(id),
filename TEXT NOT NULL,
file_data BLOB NOT NULL,
mime_type TEXT
);
CREATE TABLE IF NOT EXISTS generation_shares (
id INTEGER PRIMARY KEY AUTOINCREMENT,
share_token TEXT UNIQUE NOT NULL,
prompt_id TEXT NOT NULL,
owner_label TEXT NOT NULL,
created_at TEXT NOT NULL,
is_public INTEGER NOT NULL DEFAULT 0,
expires_at TEXT,
max_views INTEGER,
view_count INTEGER NOT NULL DEFAULT 0
);
"""
def _migrate_shares_table(conn: sqlite3.Connection) -> None:
migrations = [
"ALTER TABLE generation_shares ADD COLUMN is_public INTEGER NOT NULL DEFAULT 0",
"ALTER TABLE generation_shares ADD COLUMN expires_at TEXT",
"ALTER TABLE generation_shares ADD COLUMN max_views INTEGER",
"ALTER TABLE generation_shares ADD COLUMN view_count INTEGER NOT NULL DEFAULT 0",
]
for sql in migrations:
try:
conn.execute(sql)
except sqlite3.OperationalError:
pass # column already exists
conn.commit()
def _connect(db_path: Path | None = None) -> sqlite3.Connection:
path = db_path if db_path is not None else _DB_PATH
conn = sqlite3.connect(str(path), check_same_thread=False)
conn.row_factory = sqlite3.Row
return conn
def _detect_mime(data: bytes) -> str:
"""Detect MIME type from magic bytes."""
if data[:8] == b"\x89PNG\r\n\x1a\n":
return "image/png"
if data[:2] == b"\xff\xd8":
return "image/jpeg"
if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"WEBP":
return "image/webp"
if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"AVI ":
return "video/x-msvideo"
if len(data) >= 8 and data[4:8] == b"ftyp":
return "video/mp4"
if data[:4] == b"\x1aE\xdf\xa3": # EBML (WebM/MKV)
return "video/webm"
return "application/octet-stream"
def init_db(db_path: Path = _DB_PATH) -> None:
    """Create tables if they don't exist. Accepts a path for testability.

    Records *db_path* in the module-level ``_DB_PATH`` so every other
    function in this module talks to the same database, then runs the
    idempotent column migrations for ``generation_shares``.
    """
    global _DB_PATH
    _DB_PATH = db_path
    # closing() is required: sqlite3's "with conn:" manages the transaction
    # only and never closes the connection, so it would otherwise leak.
    with closing(_connect(db_path)) as conn:
        conn.executescript(_SCHEMA)
        conn.commit()
        _migrate_shares_table(conn)
def record_generation(
    prompt_id: str,
    source: str,
    user_label: str | None,
    overrides_dict: dict[str, Any] | None,
    seed: int | None,
) -> int:
    """Insert a generation history row. Returns the auto-increment ``id``.

    *overrides_dict* is serialised to JSON (or stored as NULL when None);
    ``created_at`` is an ISO-8601 UTC timestamp.
    """
    overrides_json = json.dumps(overrides_dict) if overrides_dict is not None else None
    created_at = datetime.now(timezone.utc).isoformat()
    # closing(): "with conn:" alone only manages the transaction — the
    # connection itself would leak on every call.
    with closing(_connect()) as conn:
        cur = conn.execute(
            """
            INSERT INTO generation_history
                (prompt_id, source, user_label, overrides, seed, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (prompt_id, source, user_label, overrides_json, seed, created_at),
        )
        conn.commit()
        return cur.lastrowid  # type: ignore[return-value]
def record_file(generation_id: int, filename: str, file_data: bytes) -> int:
    """Insert a file BLOB row, auto-detecting MIME type from magic bytes.

    Returns the new ``generation_files.id``.
    """
    mime_type = _detect_mime(file_data)
    # closing(): sqlite3's connection context manager never closes the
    # connection, only commits/rolls back — close explicitly.
    with closing(_connect()) as conn:
        cur = conn.execute(
            """
            INSERT INTO generation_files (generation_id, filename, file_data, mime_type)
            VALUES (?, ?, ?, ?)
            """,
            (generation_id, filename, file_data, mime_type),
        )
        conn.commit()
        return cur.lastrowid  # type: ignore[return-value]
def _rows_to_history(conn: sqlite3.Connection, rows) -> list[dict]:
"""Convert raw generation_history rows (with optional share_token) to dicts."""
result: list[dict] = []
for row in rows:
d = dict(row)
if d["overrides"]:
try:
d["overrides"] = json.loads(d["overrides"])
except (json.JSONDecodeError, TypeError):
d["overrides"] = {}
else:
d["overrides"] = {}
files = conn.execute(
"SELECT filename FROM generation_files WHERE generation_id = ?",
(d["id"],),
).fetchall()
d["file_paths"] = [f["filename"] for f in files]
result.append(d)
return result
def get_history(limit: int = 50) -> list[dict]:
    """Return recent generation rows (newest first) with a ``file_paths`` list.

    Each row also carries the owner's share columns (token, visibility,
    expiry, view counts) via a LEFT JOIN, or NULLs when no share exists.
    """
    # closing(): sqlite3 connections are never closed by "with conn:" —
    # without it, every call leaks a file handle.
    with closing(_connect()) as conn:
        rows = conn.execute(
            """
            SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
                   s.share_token, s.is_public AS share_is_public,
                   s.expires_at AS share_expires_at, s.max_views AS share_max_views,
                   s.view_count AS share_view_count
            FROM generation_history h
            LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label
            ORDER BY h.id DESC LIMIT ?
            """,
            (limit,),
        ).fetchall()
        return _rows_to_history(conn, rows)
def get_history_for_user(user_label: str, limit: int = 50) -> list[dict]:
    """Return recent generation rows for a specific user (newest first).

    Same shape as :func:`get_history`, filtered to *user_label* and joining
    only that user's own shares.
    """
    # closing() prevents the per-call connection leak ("with conn:" only
    # manages the transaction in sqlite3).
    with closing(_connect()) as conn:
        rows = conn.execute(
            """
            SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
                   s.share_token, s.is_public AS share_is_public,
                   s.expires_at AS share_expires_at, s.max_views AS share_max_views,
                   s.view_count AS share_view_count
            FROM generation_history h
            LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ?
            WHERE h.user_label = ?
            ORDER BY h.id DESC LIMIT ?
            """,
            (user_label, user_label, limit),
        ).fetchall()
        return _rows_to_history(conn, rows)
def get_generation(prompt_id: str) -> dict | None:
    """Return the generation_history row (id, prompt_id, user_label) for *prompt_id*, or None."""
    # closing() so the read-only connection is actually released.
    with closing(_connect()) as conn:
        row = conn.execute(
            "SELECT id, prompt_id, user_label FROM generation_history WHERE prompt_id = ?",
            (prompt_id,),
        ).fetchone()
        return dict(row) if row else None
def get_generation_full(prompt_id: str) -> dict | None:
    """Return overrides (parsed to a dict) + seed for *prompt_id*, or None if not found."""
    # closing() releases the connection; "with conn:" alone never does.
    with closing(_connect()) as conn:
        row = conn.execute(
            "SELECT prompt_id, user_label, overrides, seed FROM generation_history WHERE prompt_id = ?",
            (prompt_id,),
        ).fetchone()
    if row is None:
        return None
    d = dict(row)
    # Malformed or empty overrides degrade to {} rather than raising.
    try:
        d["overrides"] = json.loads(d["overrides"]) if d["overrides"] else {}
    except (json.JSONDecodeError, TypeError):
        d["overrides"] = {}
    return d
def search_history_for_user(user_label: str, query: str, limit: int = 50) -> list[dict]:
    """Return history rows where the overrides JSON contains *query* (case-insensitive).

    NOTE(review): *query* is embedded in a LIKE pattern, so "%" and "_" in
    user input act as wildcards — confirm that is intended.
    """
    # closing() fixes the connection leak ("with conn:" only handles the
    # transaction in sqlite3).
    with closing(_connect()) as conn:
        rows = conn.execute(
            """
            SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
                   s.share_token, s.is_public AS share_is_public,
                   s.expires_at AS share_expires_at, s.max_views AS share_max_views,
                   s.view_count AS share_view_count
            FROM generation_history h
            LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ?
            WHERE h.user_label = ? AND LOWER(h.overrides) LIKE LOWER(?)
            ORDER BY h.id DESC LIMIT ?
            """,
            (user_label, user_label, f"%{query}%", limit),
        ).fetchall()
        return _rows_to_history(conn, rows)
def search_history(query: str, limit: int = 50) -> list[dict]:
    """Admin version: search all users' history for *query* in overrides JSON.

    NOTE(review): "%" and "_" in *query* act as LIKE wildcards.
    """
    # closing() fixes the connection leak.
    with closing(_connect()) as conn:
        rows = conn.execute(
            """
            SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
                   s.share_token, s.is_public AS share_is_public,
                   s.expires_at AS share_expires_at, s.max_views AS share_max_views,
                   s.view_count AS share_view_count
            FROM generation_history h
            LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label
            WHERE LOWER(h.overrides) LIKE LOWER(?)
            ORDER BY h.id DESC LIMIT ?
            """,
            (f"%{query}%", limit),
        ).fetchall()
        return _rows_to_history(conn, rows)
def _is_share_expired(share_row: dict) -> bool:
"""Return True if the share has passed its time or view limits."""
if share_row["expires_at"] and datetime.fromisoformat(share_row["expires_at"]) <= datetime.now(timezone.utc):
return True
if share_row["max_views"] is not None and share_row["view_count"] >= share_row["max_views"]:
return True
return False
def _is_share_streaming_expired(share_row: dict) -> bool:
"""Expiry check for file-streaming calls (no view_count increment).
Uses strict-greater-than so that files remain accessible within the
same page view that just consumed the last allowed view."""
if share_row["expires_at"] and datetime.fromisoformat(share_row["expires_at"]) <= datetime.now(timezone.utc):
return True
if share_row["max_views"] is not None and share_row["view_count"] > share_row["max_views"]:
return True
return False
def get_active_share_for_prompt(prompt_id: str, owner_label: str) -> dict | None:
    """Return the active (non-expired) share row for *prompt_id*+*owner_label*, or None.

    If a row exists but is already expired, it is auto-deleted and None is
    returned so the caller can create a new share immediately.
    """
    # closing() fixes the leaked connection ("with conn:" never closes it).
    with closing(_connect()) as conn:
        row = conn.execute(
            """
            SELECT share_token, is_public, expires_at, max_views, view_count
            FROM generation_shares
            WHERE prompt_id = ? AND owner_label = ?
            """,
            (prompt_id, owner_label),
        ).fetchone()
        if row is None:
            return None
        d = dict(row)
        if _is_share_expired(d):
            # Lazy cleanup: drop the stale share so a fresh one can be made.
            conn.execute(
                "DELETE FROM generation_shares WHERE share_token = ?",
                (d["share_token"],),
            )
            conn.commit()
            return None
        return d
def create_share(
    prompt_id: str,
    owner_label: str,
    *,
    is_public: bool = False,
    expires_at: str | None = None,
    max_views: int | None = None,
) -> dict:
    """Insert a fresh share row.

    Returns a dict with share_token, is_public, expires_at and max_views.
    The token comes from :mod:`secrets`, so it is unguessable.
    """
    token = secrets.token_urlsafe(32)
    created_at = datetime.now(timezone.utc).isoformat()
    # closing() fixes the per-call connection leak.
    with closing(_connect()) as conn:
        conn.execute(
            """
            INSERT INTO generation_shares
                (share_token, prompt_id, owner_label, created_at, is_public, expires_at, max_views, view_count)
            VALUES (?, ?, ?, ?, ?, ?, ?, 0)
            """,
            (token, prompt_id, owner_label, created_at, int(is_public), expires_at, max_views),
        )
        conn.commit()
        # Re-read the row so the caller gets exactly what was stored.
        row = conn.execute(
            "SELECT share_token, is_public, expires_at, max_views FROM generation_shares WHERE share_token = ?",
            (token,),
        ).fetchone()
        return dict(row)
def revoke_share(prompt_id: str, owner_label: str | None = None) -> bool:
    """Delete the share token for *prompt_id*.

    If *owner_label* is provided, only that user's share is deleted.
    If None (admin), any share for the prompt_id is deleted.
    Returns True if a row was deleted.
    """
    # closing() fixes the connection leak.
    with closing(_connect()) as conn:
        if owner_label is not None:
            cur = conn.execute(
                "DELETE FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
                (prompt_id, owner_label),
            )
        else:
            cur = conn.execute(
                "DELETE FROM generation_shares WHERE prompt_id = ?",
                (prompt_id,),
            )
        conn.commit()
        return cur.rowcount > 0
def get_share_by_token(token: str) -> dict | None:
    """Return generation info for a share token (incrementing view_count), or None if not found/expired.

    Expired shares are deleted on access.  The returned ``view_count`` is
    the value *before* this call's increment.
    """
    # closing() fixes the connection leak.
    with closing(_connect()) as conn:
        row = conn.execute(
            """
            SELECT s.share_token, s.is_public, s.expires_at, s.max_views, s.view_count,
                   h.prompt_id, h.overrides, h.seed, h.created_at
            FROM generation_shares s
            JOIN generation_history h ON h.prompt_id = s.prompt_id
            WHERE s.share_token = ?
            """,
            (token,),
        ).fetchone()
        if row is None:
            return None
        d = dict(row)
        if _is_share_expired(d):
            conn.execute("DELETE FROM generation_shares WHERE share_token = ?", (token,))
            conn.commit()
            return None
        # Count this access; d keeps the pre-increment value so the
        # streaming expiry check (strict >) still admits this page's files.
        conn.execute(
            "UPDATE generation_shares SET view_count = view_count + 1 WHERE share_token = ?",
            (token,),
        )
        conn.commit()
    try:
        d["overrides"] = json.loads(d["overrides"]) if d["overrides"] else {}
    except (json.JSONDecodeError, TypeError):
        d["overrides"] = {}
    return d
def get_share_meta(token: str) -> dict | None:
    """Return share metadata without incrementing view_count. Used by file-streaming endpoints.

    Uses the streaming expiry check (strict > on view_count) so files stay
    reachable within the page view that consumed the last allowed view.
    """
    # closing() fixes the connection leak.
    with closing(_connect()) as conn:
        row = conn.execute(
            """
            SELECT s.share_token, s.is_public, s.expires_at, s.max_views, s.view_count,
                   h.prompt_id, h.overrides, h.seed, h.created_at
            FROM generation_shares s
            JOIN generation_history h ON h.prompt_id = s.prompt_id
            WHERE s.share_token = ?
            """,
            (token,),
        ).fetchone()
        if row is None:
            return None
        d = dict(row)
        if _is_share_streaming_expired(d):
            conn.execute("DELETE FROM generation_shares WHERE share_token = ?", (token,))
            conn.commit()
            return None
    try:
        d["overrides"] = json.loads(d["overrides"]) if d["overrides"] else {}
    except (json.JSONDecodeError, TypeError):
        d["overrides"] = {}
    return d
def get_file_ids_for_prompt(prompt_id: str) -> list[int]:
    """Return generation_files.id values for all files belonging to *prompt_id*."""
    # closing() fixes the connection leak.
    with closing(_connect()) as conn:
        rows = conn.execute(
            """SELECT gf.id FROM generation_files gf
            JOIN generation_history gh ON gh.id = gf.generation_id
            WHERE gh.prompt_id = ?""",
            (prompt_id,),
        ).fetchall()
        return [r["id"] for r in rows]
def get_generation_ids_for_file_ids(file_ids: list[int]) -> list[int]:
    """Return distinct generation_id values for the given generation_files row ids.

    Returns [] for empty input.
    """
    if not file_ids:
        return []
    # Safe dynamic SQL: only "?" placeholders are interpolated.
    placeholders = ",".join("?" * len(file_ids))
    # closing() fixes the connection leak.
    with closing(_connect()) as conn:
        rows = conn.execute(
            f"SELECT DISTINCT generation_id FROM generation_files WHERE id IN ({placeholders})",
            tuple(file_ids),
        ).fetchall()
        return [r["generation_id"] for r in rows]
def get_file_ids_for_generation_ids(gen_ids: list[int]) -> dict[int, list[int]]:
    """Return {gen_id: [file_id, …]} for the given generation_history row ids.

    Every requested gen_id appears in the result, mapped to [] when it has
    no files.  Returns {} for empty input.
    """
    if not gen_ids:
        return {}
    # Safe dynamic SQL: only "?" placeholders are interpolated.
    placeholders = ",".join("?" * len(gen_ids))
    # closing() fixes the connection leak.
    with closing(_connect()) as conn:
        rows = conn.execute(
            f"SELECT generation_id, id FROM generation_files WHERE generation_id IN ({placeholders})",
            tuple(gen_ids),
        ).fetchall()
    result: dict[int, list[int]] = {gid: [] for gid in gen_ids}
    for row in rows:
        result[row["generation_id"]].append(row["id"])
    return result
def get_files(prompt_id: str) -> list[dict]:
    """Return all output files for *prompt_id* as ``[{filename, data, mime_type}]``.

    ``data`` is coerced to ``bytes``.  Returns [] when the prompt is unknown
    or has no files.
    """
    # closing() fixes the connection leak.
    with closing(_connect()) as conn:
        gen_row = conn.execute(
            "SELECT id FROM generation_history WHERE prompt_id = ?",
            (prompt_id,),
        ).fetchone()
        if not gen_row:
            return []
        files = conn.execute(
            "SELECT filename, file_data, mime_type FROM generation_files WHERE generation_id = ?",
            (gen_row["id"],),
        ).fetchall()
        return [
            {
                "filename": f["filename"],
                "data": bytes(f["file_data"]),
                "mime_type": f["mime_type"],
            }
            for f in files
        ]