Initial commit — ComfyUI Discord bot + web UI
Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI (React/TS/Vite/Tailwind), ComfyUI integration, generation history DB, preset manager, workflow inspector, and all supporting modules. Excluded from tracking: .env, invite_tokens.json, *.db (SQLite), current-workflow-changes.json, user_settings/, presets/, logs/, web-static/ (build output), frontend/node_modules/. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
322
generation_db.py
Normal file
322
generation_db.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
generation_db.py
|
||||
================
|
||||
|
||||
SQLite persistence for ComfyUI generation history and output file blobs.
|
||||
|
||||
Two tables
|
||||
----------
|
||||
generation_history : one row per prompt submitted to ComfyUI
|
||||
generation_files : one row per output file (image / video) as a BLOB
|
||||
|
||||
The module-level ``_DB_PATH`` is set by :func:`init_db`; all other
|
||||
functions use that path so callers never need to pass it around.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Module-level database location; init_db() overwrites this so every helper
# below reads the same path without it being threaded through call sites.
_DB_PATH: Path = Path(__file__).parent / "generation_history.db"

# DDL run by init_db() via executescript(); idempotent thanks to IF NOT EXISTS.
# NOTE(review): generation_shares has no UNIQUE(prompt_id, owner_label) —
# one-share-per-(prompt, owner) must be enforced in application code.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS generation_history (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    prompt_id TEXT UNIQUE NOT NULL,
    source TEXT NOT NULL,
    user_label TEXT,
    overrides TEXT,
    seed INTEGER,
    created_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS generation_files (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    generation_id INTEGER NOT NULL REFERENCES generation_history(id),
    filename TEXT NOT NULL,
    file_data BLOB NOT NULL,
    mime_type TEXT
);

CREATE TABLE IF NOT EXISTS generation_shares (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    share_token TEXT UNIQUE NOT NULL,
    prompt_id TEXT NOT NULL,
    owner_label TEXT NOT NULL,
    created_at TEXT NOT NULL
);
"""
|
||||
|
||||
|
||||
def _connect(db_path: Path | None = None) -> sqlite3.Connection:
|
||||
path = db_path if db_path is not None else _DB_PATH
|
||||
conn = sqlite3.connect(str(path), check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
def _detect_mime(data: bytes) -> str:
    """Detect the MIME type of *data* from its magic bytes.

    Recognises PNG, JPEG, GIF, WebP, AVI, MP4-family (``ftyp`` brand) and
    EBML containers (WebM/MKV).  Anything unrecognised — including empty
    input; every slice check is safe on short buffers — falls back to
    ``application/octet-stream``.
    """
    if data[:8] == b"\x89PNG\r\n\x1a\n":
        return "image/png"
    if data[:2] == b"\xff\xd8":
        return "image/jpeg"
    # GIF87a and GIF89a are the only two GIF signatures.
    if data[:6] in (b"GIF87a", b"GIF89a"):
        return "image/gif"
    if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "image/webp"
    if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"AVI ":
        return "video/x-msvideo"
    # ISO BMFF: the 'ftyp' box type sits at offset 4, after the 4-byte box size.
    if len(data) >= 8 and data[4:8] == b"ftyp":
        return "video/mp4"
    if data[:4] == b"\x1aE\xdf\xa3":  # EBML header (WebM/MKV)
        return "video/webm"
    return "application/octet-stream"
|
||||
|
||||
|
||||
def init_db(db_path: Path = _DB_PATH) -> None:
    """Create the schema (if missing) and remember *db_path* module-wide.

    Every later call to the helpers in this module goes through the
    updated ``_DB_PATH``, so callers never pass the path again.
    """
    global _DB_PATH
    _DB_PATH = db_path
    conn = _connect(db_path)
    with conn:
        conn.executescript(_SCHEMA)
        conn.commit()
|
||||
|
||||
|
||||
def record_generation(
    prompt_id: str,
    source: str,
    user_label: str | None,
    overrides_dict: dict[str, Any] | None,
    seed: int | None,
) -> int:
    """Persist one generation_history row and return its auto-increment id.

    ``overrides_dict`` is serialised to a JSON string (SQL NULL when
    absent); ``created_at`` is an ISO-8601 UTC timestamp.
    """
    serialized = None if overrides_dict is None else json.dumps(overrides_dict)
    timestamp = datetime.now(timezone.utc).isoformat()
    with _connect() as conn:
        cursor = conn.execute(
            "INSERT INTO generation_history"
            " (prompt_id, source, user_label, overrides, seed, created_at)"
            " VALUES (?, ?, ?, ?, ?, ?)",
            (prompt_id, source, user_label, serialized, seed, timestamp),
        )
        conn.commit()
        return cursor.lastrowid  # type: ignore[return-value]
|
||||
|
||||
|
||||
def record_file(generation_id: int, filename: str, file_data: bytes) -> None:
    """Store one output file as a BLOB, sniffing its MIME type from magic bytes."""
    with _connect() as conn:
        conn.execute(
            "INSERT INTO generation_files (generation_id, filename, file_data, mime_type)"
            " VALUES (?, ?, ?, ?)",
            (generation_id, filename, file_data, _detect_mime(file_data)),
        )
        conn.commit()
|
||||
|
||||
|
||||
def _rows_to_history(conn: sqlite3.Connection, rows) -> list[dict]:
    """Convert generation_history rows into plain dicts.

    ``overrides`` is JSON-decoded (empty or malformed JSON becomes ``{}``)
    and each entry gains a ``file_paths`` list of output filenames, looked
    up in ``generation_files`` on the same connection.
    """
    history: list[dict] = []
    for raw in rows:
        entry = dict(raw)
        blob = entry["overrides"]
        parsed = {}
        if blob:
            try:
                parsed = json.loads(blob)
            except (json.JSONDecodeError, TypeError):
                parsed = {}
        entry["overrides"] = parsed

        file_rows = conn.execute(
            "SELECT filename FROM generation_files WHERE generation_id = ?",
            (entry["id"],),
        ).fetchall()
        entry["file_paths"] = [r["filename"] for r in file_rows]
        history.append(entry)
    return history
|
||||
|
||||
|
||||
def get_history(limit: int = 50) -> list[dict]:
    """Return up to *limit* most recent generations (all users), newest first.

    Each dict carries parsed ``overrides``, a ``file_paths`` list, and the
    owner's ``share_token`` (NULL when the generation is unshared).
    """
    query = """
        SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
               s.share_token
        FROM generation_history h
        LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label
        ORDER BY h.id DESC LIMIT ?
    """
    with _connect() as conn:
        fetched = conn.execute(query, (limit,)).fetchall()
        return _rows_to_history(conn, fetched)
|
||||
|
||||
|
||||
def get_history_for_user(user_label: str, limit: int = 50) -> list[dict]:
    """Return up to *limit* most recent generations owned by *user_label*.

    Same shape as :func:`get_history`; the share join is restricted to
    shares created by this user.
    """
    query = """
        SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
               s.share_token
        FROM generation_history h
        LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ?
        WHERE h.user_label = ?
        ORDER BY h.id DESC LIMIT ?
    """
    with _connect() as conn:
        fetched = conn.execute(query, (user_label, user_label, limit)).fetchall()
        return _rows_to_history(conn, fetched)
|
||||
|
||||
|
||||
def get_generation(prompt_id: str) -> dict | None:
    """Look up a generation by *prompt_id*.

    Returns ``{id, prompt_id, user_label}`` or None when unknown.
    """
    with _connect() as conn:
        found = conn.execute(
            "SELECT id, prompt_id, user_label FROM generation_history WHERE prompt_id = ?",
            (prompt_id,),
        ).fetchone()
    return None if found is None else dict(found)
|
||||
|
||||
|
||||
def get_generation_full(prompt_id: str) -> dict | None:
    """Return prompt_id, user_label, parsed ``overrides`` and seed for *prompt_id*.

    Returns None when the prompt is unknown; empty or malformed overrides
    JSON decodes to ``{}``.
    """
    with _connect() as conn:
        found = conn.execute(
            "SELECT prompt_id, user_label, overrides, seed FROM generation_history WHERE prompt_id = ?",
            (prompt_id,),
        ).fetchone()
    if found is None:
        return None
    record = dict(found)
    blob = record["overrides"]
    if not blob:
        record["overrides"] = {}
    else:
        try:
            record["overrides"] = json.loads(blob)
        except (json.JSONDecodeError, TypeError):
            record["overrides"] = {}
    return record
|
||||
|
||||
|
||||
def search_history_for_user(user_label: str, query: str, limit: int = 50) -> list[dict]:
    """Case-insensitive substring search over *user_label*'s overrides JSON.

    Note: ``%`` and ``_`` inside *query* act as SQL LIKE wildcards.
    """
    pattern = f"%{query}%"
    sql = """
        SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
               s.share_token
        FROM generation_history h
        LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ?
        WHERE h.user_label = ? AND LOWER(h.overrides) LIKE LOWER(?)
        ORDER BY h.id DESC LIMIT ?
    """
    with _connect() as conn:
        matches = conn.execute(sql, (user_label, user_label, pattern, limit)).fetchall()
        return _rows_to_history(conn, matches)
|
||||
|
||||
|
||||
def search_history(query: str, limit: int = 50) -> list[dict]:
    """Admin variant of :func:`search_history_for_user`: searches every user.

    Matches *query* as a case-insensitive substring of the overrides JSON;
    ``%`` and ``_`` act as SQL LIKE wildcards.
    """
    pattern = f"%{query}%"
    sql = """
        SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
               s.share_token
        FROM generation_history h
        LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label
        WHERE LOWER(h.overrides) LIKE LOWER(?)
        ORDER BY h.id DESC LIMIT ?
    """
    with _connect() as conn:
        matches = conn.execute(sql, (pattern, limit)).fetchall()
        return _rows_to_history(conn, matches)
|
||||
|
||||
|
||||
def create_share(prompt_id: str, owner_label: str) -> str:
    """Create a share token for *prompt_id*, or return the existing one.

    Bug fix: the previous ``INSERT OR IGNORE`` only guarded the UNIQUE
    ``share_token`` column — and the token is freshly random on every
    call, so nothing was ever ignored: repeated calls piled up duplicate
    (prompt_id, owner_label) rows and the trailing SELECT returned an
    arbitrary one.  We now look up an existing share first and only
    insert when none exists, making the function genuinely idempotent.
    """
    with _connect() as conn:
        existing = conn.execute(
            "SELECT share_token FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
            (prompt_id, owner_label),
        ).fetchone()
        if existing is not None:
            return existing["share_token"]

        token = secrets.token_urlsafe(32)
        created_at = datetime.now(timezone.utc).isoformat()
        conn.execute(
            """
            INSERT INTO generation_shares (share_token, prompt_id, owner_label, created_at)
            VALUES (?, ?, ?, ?)
            """,
            (token, prompt_id, owner_label, created_at),
        )
        conn.commit()
        return token
|
||||
|
||||
|
||||
def revoke_share(prompt_id: str, owner_label: str) -> bool:
    """Delete the share for (*prompt_id*, *owner_label*); True if anything was removed."""
    with _connect() as conn:
        deleted = conn.execute(
            "DELETE FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
            (prompt_id, owner_label),
        ).rowcount
        conn.commit()
    return deleted > 0
|
||||
|
||||
|
||||
def get_share_by_token(token: str) -> dict | None:
    """Resolve a share *token* to its generation's details.

    Returns ``{prompt_id, overrides, seed, created_at}`` with
    ``overrides`` JSON-decoded (``{}`` on empty/malformed JSON), or None
    for unknown or revoked tokens.
    """
    sql = """
        SELECT h.prompt_id, h.overrides, h.seed, h.created_at
        FROM generation_shares s
        JOIN generation_history h ON h.prompt_id = s.prompt_id
        WHERE s.share_token = ?
    """
    with _connect() as conn:
        hit = conn.execute(sql, (token,)).fetchone()
    if hit is None:
        return None
    share = dict(hit)
    raw = share["overrides"]
    if not raw:
        share["overrides"] = {}
    else:
        try:
            share["overrides"] = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            share["overrides"] = {}
    return share
|
||||
|
||||
|
||||
def get_files(prompt_id: str) -> list[dict]:
    """Fetch every stored output file for *prompt_id*.

    Returns ``[{filename, data, mime_type}]`` with ``data`` as real
    ``bytes``, or an empty list when the prompt is unknown.
    """
    with _connect() as conn:
        owner = conn.execute(
            "SELECT id FROM generation_history WHERE prompt_id = ?",
            (prompt_id,),
        ).fetchone()
        if owner is None:
            return []

        blobs = conn.execute(
            "SELECT filename, file_data, mime_type FROM generation_files WHERE generation_id = ?",
            (owner["id"],),
        ).fetchall()
    return [
        {
            "filename": blob["filename"],
            "data": bytes(blob["file_data"]),
            "mime_type": blob["mime_type"],
        }
        for blob in blobs
    ]
|
||||
Reference in New Issue
Block a user