manual submit
This commit is contained in:
226
generation_db.py
226
generation_db.py
@@ -48,11 +48,30 @@ CREATE TABLE IF NOT EXISTS generation_shares (
|
||||
share_token TEXT UNIQUE NOT NULL,
|
||||
prompt_id TEXT NOT NULL,
|
||||
owner_label TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
created_at TEXT NOT NULL,
|
||||
is_public INTEGER NOT NULL DEFAULT 0,
|
||||
expires_at TEXT,
|
||||
max_views INTEGER,
|
||||
view_count INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
def _migrate_shares_table(conn: sqlite3.Connection) -> None:
|
||||
migrations = [
|
||||
"ALTER TABLE generation_shares ADD COLUMN is_public INTEGER NOT NULL DEFAULT 0",
|
||||
"ALTER TABLE generation_shares ADD COLUMN expires_at TEXT",
|
||||
"ALTER TABLE generation_shares ADD COLUMN max_views INTEGER",
|
||||
"ALTER TABLE generation_shares ADD COLUMN view_count INTEGER NOT NULL DEFAULT 0",
|
||||
]
|
||||
for sql in migrations:
|
||||
try:
|
||||
conn.execute(sql)
|
||||
except sqlite3.OperationalError:
|
||||
pass # column already exists
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _connect(db_path: Path | None = None) -> sqlite3.Connection:
|
||||
path = db_path if db_path is not None else _DB_PATH
|
||||
conn = sqlite3.connect(str(path), check_same_thread=False)
|
||||
@@ -84,6 +103,7 @@ def init_db(db_path: Path = _DB_PATH) -> None:
|
||||
with _connect(db_path) as conn:
|
||||
conn.executescript(_SCHEMA)
|
||||
conn.commit()
|
||||
_migrate_shares_table(conn)
|
||||
|
||||
|
||||
def record_generation(
|
||||
@@ -109,11 +129,11 @@ def record_generation(
|
||||
return cur.lastrowid # type: ignore[return-value]
|
||||
|
||||
|
||||
def record_file(generation_id: int, filename: str, file_data: bytes) -> None:
|
||||
"""Insert a file BLOB row, auto-detecting MIME type from magic bytes."""
|
||||
def record_file(generation_id: int, filename: str, file_data: bytes) -> int:
|
||||
"""Insert a file BLOB row, auto-detecting MIME type from magic bytes. Returns the row id."""
|
||||
mime_type = _detect_mime(file_data)
|
||||
with _connect() as conn:
|
||||
conn.execute(
|
||||
cur = conn.execute(
|
||||
"""
|
||||
INSERT INTO generation_files (generation_id, filename, file_data, mime_type)
|
||||
VALUES (?, ?, ?, ?)
|
||||
@@ -121,6 +141,7 @@ def record_file(generation_id: int, filename: str, file_data: bytes) -> None:
|
||||
(generation_id, filename, file_data, mime_type),
|
||||
)
|
||||
conn.commit()
|
||||
return cur.lastrowid # type: ignore[return-value]
|
||||
|
||||
|
||||
def _rows_to_history(conn: sqlite3.Connection, rows) -> list[dict]:
|
||||
@@ -151,7 +172,9 @@ def get_history(limit: int = 50) -> list[dict]:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
|
||||
s.share_token
|
||||
s.share_token, s.is_public AS share_is_public,
|
||||
s.expires_at AS share_expires_at, s.max_views AS share_max_views,
|
||||
s.view_count AS share_view_count
|
||||
FROM generation_history h
|
||||
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label
|
||||
ORDER BY h.id DESC LIMIT ?
|
||||
@@ -167,7 +190,9 @@ def get_history_for_user(user_label: str, limit: int = 50) -> list[dict]:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
|
||||
s.share_token
|
||||
s.share_token, s.is_public AS share_is_public,
|
||||
s.expires_at AS share_expires_at, s.max_views AS share_max_views,
|
||||
s.view_count AS share_view_count
|
||||
FROM generation_history h
|
||||
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ?
|
||||
WHERE h.user_label = ?
|
||||
@@ -214,7 +239,9 @@ def search_history_for_user(user_label: str, query: str, limit: int = 50) -> lis
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
|
||||
s.share_token
|
||||
s.share_token, s.is_public AS share_is_public,
|
||||
s.expires_at AS share_expires_at, s.max_views AS share_max_views,
|
||||
s.view_count AS share_view_count
|
||||
FROM generation_history h
|
||||
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ?
|
||||
WHERE h.user_label = ? AND LOWER(h.overrides) LIKE LOWER(?)
|
||||
@@ -231,7 +258,9 @@ def search_history(query: str, limit: int = 50) -> list[dict]:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
|
||||
s.share_token
|
||||
s.share_token, s.is_public AS share_is_public,
|
||||
s.expires_at AS share_expires_at, s.max_views AS share_max_views,
|
||||
s.view_count AS share_view_count
|
||||
FROM generation_history h
|
||||
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label
|
||||
WHERE LOWER(h.overrides) LIKE LOWER(?)
|
||||
@@ -242,43 +271,111 @@ def search_history(query: str, limit: int = 50) -> list[dict]:
|
||||
return _rows_to_history(conn, rows)
|
||||
|
||||
|
||||
def create_share(prompt_id: str, owner_label: str) -> str:
|
||||
"""Create a share token for *prompt_id*. Idempotent — returns the same token if one exists."""
|
||||
def _is_share_expired(share_row: dict) -> bool:
|
||||
"""Return True if the share has passed its time or view limits."""
|
||||
if share_row["expires_at"] and datetime.fromisoformat(share_row["expires_at"]) <= datetime.now(timezone.utc):
|
||||
return True
|
||||
if share_row["max_views"] is not None and share_row["view_count"] >= share_row["max_views"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_share_streaming_expired(share_row: dict) -> bool:
|
||||
"""Expiry check for file-streaming calls (no view_count increment).
|
||||
Uses strict-greater-than so that files remain accessible within the
|
||||
same page view that just consumed the last allowed view."""
|
||||
if share_row["expires_at"] and datetime.fromisoformat(share_row["expires_at"]) <= datetime.now(timezone.utc):
|
||||
return True
|
||||
if share_row["max_views"] is not None and share_row["view_count"] > share_row["max_views"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_active_share_for_prompt(prompt_id: str, owner_label: str) -> dict | None:
|
||||
"""Return the active (non-expired) share row for *prompt_id*+*owner_label*, or None.
|
||||
|
||||
If a row exists but is already expired, it is auto-deleted and None is returned
|
||||
so the caller can create a new share immediately.
|
||||
"""
|
||||
with _connect() as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT share_token, is_public, expires_at, max_views, view_count
|
||||
FROM generation_shares
|
||||
WHERE prompt_id = ? AND owner_label = ?
|
||||
""",
|
||||
(prompt_id, owner_label),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
d = dict(row)
|
||||
if _is_share_expired(d):
|
||||
conn.execute(
|
||||
"DELETE FROM generation_shares WHERE share_token = ?",
|
||||
(d["share_token"],),
|
||||
)
|
||||
conn.commit()
|
||||
return None
|
||||
return d
|
||||
|
||||
|
||||
def create_share(
|
||||
prompt_id: str,
|
||||
owner_label: str,
|
||||
*,
|
||||
is_public: bool = False,
|
||||
expires_at: str | None = None,
|
||||
max_views: int | None = None,
|
||||
) -> dict:
|
||||
"""Insert a fresh share row. Returns dict with share_token, is_public, expires_at, max_views."""
|
||||
token = secrets.token_urlsafe(32)
|
||||
created_at = datetime.now(timezone.utc).isoformat()
|
||||
with _connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO generation_shares (share_token, prompt_id, owner_label, created_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
INSERT INTO generation_shares
|
||||
(share_token, prompt_id, owner_label, created_at, is_public, expires_at, max_views, view_count)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, 0)
|
||||
""",
|
||||
(token, prompt_id, owner_label, created_at),
|
||||
(token, prompt_id, owner_label, created_at, int(is_public), expires_at, max_views),
|
||||
)
|
||||
conn.commit()
|
||||
row = conn.execute(
|
||||
"SELECT share_token FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
|
||||
(prompt_id, owner_label),
|
||||
"SELECT share_token, is_public, expires_at, max_views FROM generation_shares WHERE share_token = ?",
|
||||
(token,),
|
||||
).fetchone()
|
||||
return row["share_token"]
|
||||
return dict(row)
|
||||
|
||||
|
||||
def revoke_share(prompt_id: str, owner_label: str) -> bool:
|
||||
"""Delete the share token for *prompt_id*. Returns True if a row was deleted."""
|
||||
def revoke_share(prompt_id: str, owner_label: str | None = None) -> bool:
|
||||
"""Delete the share token for *prompt_id*.
|
||||
|
||||
If *owner_label* is provided, only delete that user's share.
|
||||
If None (admin), delete any share for the prompt_id.
|
||||
Returns True if a row was deleted.
|
||||
"""
|
||||
with _connect() as conn:
|
||||
cur = conn.execute(
|
||||
"DELETE FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
|
||||
(prompt_id, owner_label),
|
||||
)
|
||||
if owner_label is not None:
|
||||
cur = conn.execute(
|
||||
"DELETE FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
|
||||
(prompt_id, owner_label),
|
||||
)
|
||||
else:
|
||||
cur = conn.execute(
|
||||
"DELETE FROM generation_shares WHERE prompt_id = ?",
|
||||
(prompt_id,),
|
||||
)
|
||||
conn.commit()
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
def get_share_by_token(token: str) -> dict | None:
|
||||
"""Return generation info for a share token, or None if not found/revoked."""
|
||||
"""Return generation info for a share token (incrementing view_count), or None if not found/expired."""
|
||||
with _connect() as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT h.prompt_id, h.overrides, h.seed, h.created_at
|
||||
SELECT s.share_token, s.is_public, s.expires_at, s.max_views, s.view_count,
|
||||
h.prompt_id, h.overrides, h.seed, h.created_at
|
||||
FROM generation_shares s
|
||||
JOIN generation_history h ON h.prompt_id = s.prompt_id
|
||||
WHERE s.share_token = ?
|
||||
@@ -288,6 +385,16 @@ def get_share_by_token(token: str) -> dict | None:
|
||||
if row is None:
|
||||
return None
|
||||
d = dict(row)
|
||||
if _is_share_expired(d):
|
||||
conn.execute("DELETE FROM generation_shares WHERE share_token = ?", (token,))
|
||||
conn.commit()
|
||||
return None
|
||||
# Increment view count
|
||||
conn.execute(
|
||||
"UPDATE generation_shares SET view_count = view_count + 1 WHERE share_token = ?",
|
||||
(token,),
|
||||
)
|
||||
conn.commit()
|
||||
if d["overrides"]:
|
||||
try:
|
||||
d["overrides"] = json.loads(d["overrides"])
|
||||
@@ -298,6 +405,77 @@ def get_share_by_token(token: str) -> dict | None:
|
||||
return d
|
||||
|
||||
|
||||
def get_share_meta(token: str) -> dict | None:
|
||||
"""Return share metadata without incrementing view_count. Used by file-streaming endpoints."""
|
||||
with _connect() as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT s.share_token, s.is_public, s.expires_at, s.max_views, s.view_count,
|
||||
h.prompt_id, h.overrides, h.seed, h.created_at
|
||||
FROM generation_shares s
|
||||
JOIN generation_history h ON h.prompt_id = s.prompt_id
|
||||
WHERE s.share_token = ?
|
||||
""",
|
||||
(token,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
d = dict(row)
|
||||
if _is_share_streaming_expired(d):
|
||||
conn.execute("DELETE FROM generation_shares WHERE share_token = ?", (token,))
|
||||
conn.commit()
|
||||
return None
|
||||
if d["overrides"]:
|
||||
try:
|
||||
d["overrides"] = json.loads(d["overrides"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
d["overrides"] = {}
|
||||
else:
|
||||
d["overrides"] = {}
|
||||
return d
|
||||
|
||||
|
||||
def get_file_ids_for_prompt(prompt_id: str) -> list[int]:
|
||||
"""Return generation_files.id values for all files belonging to prompt_id."""
|
||||
with _connect() as conn:
|
||||
rows = conn.execute(
|
||||
"""SELECT gf.id FROM generation_files gf
|
||||
JOIN generation_history gh ON gh.id = gf.generation_id
|
||||
WHERE gh.prompt_id = ?""",
|
||||
(prompt_id,),
|
||||
).fetchall()
|
||||
return [r["id"] for r in rows]
|
||||
|
||||
|
||||
def get_generation_ids_for_file_ids(file_ids: list[int]) -> list[int]:
|
||||
"""Return distinct generation_id values for the given generation_files row ids."""
|
||||
if not file_ids:
|
||||
return []
|
||||
placeholders = ",".join("?" * len(file_ids))
|
||||
with _connect() as conn:
|
||||
rows = conn.execute(
|
||||
f"SELECT DISTINCT generation_id FROM generation_files WHERE id IN ({placeholders})",
|
||||
tuple(file_ids),
|
||||
).fetchall()
|
||||
return [r["generation_id"] for r in rows]
|
||||
|
||||
|
||||
def get_file_ids_for_generation_ids(gen_ids: list[int]) -> dict[int, list[int]]:
|
||||
"""Return {gen_id: [file_id, …]} for the given generation_history row ids."""
|
||||
if not gen_ids:
|
||||
return {}
|
||||
placeholders = ",".join("?" * len(gen_ids))
|
||||
with _connect() as conn:
|
||||
rows = conn.execute(
|
||||
f"SELECT generation_id, id FROM generation_files WHERE generation_id IN ({placeholders})",
|
||||
tuple(gen_ids),
|
||||
).fetchall()
|
||||
result: dict[int, list[int]] = {gid: [] for gid in gen_ids}
|
||||
for row in rows:
|
||||
result[row["generation_id"]].append(row["id"])
|
||||
return result
|
||||
|
||||
|
||||
def get_files(prompt_id: str) -> list[dict]:
|
||||
"""Return all output files for *prompt_id* as ``[{filename, data, mime_type}]``."""
|
||||
with _connect() as conn:
|
||||
|
||||
Reference in New Issue
Block a user