manual submit

This commit is contained in:
Khoa (Revenovich) Tran Gia
2026-03-07 21:49:16 +07:00
parent 1748cbf8d2
commit 6004b000a7
39 changed files with 5794 additions and 614 deletions

View File

@@ -178,6 +178,7 @@ def create_app() -> FastAPI:
from web.routers.share_router import router as share_router
from web.routers.workflow_router import router as workflow_router
from web.routers.ws_router import router as ws_router
from web.routers.faces_router import router as faces_router
app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
app.include_router(admin_router, prefix="/api/admin", tags=["admin"])
@@ -191,6 +192,7 @@ def create_app() -> FastAPI:
app.include_router(share_router, prefix="/api/share", tags=["share"])
app.include_router(workflow_router, prefix="/api/workflow", tags=["workflow"])
app.include_router(ws_router, tags=["ws"])
app.include_router(faces_router, prefix="/api/faces", tags=["faces"])
from web.routers.ws_router import websocket_endpoint as _ws_endpoint
@@ -212,6 +214,11 @@ def create_app() -> FastAPI:
@app.on_event("startup")
async def _startup():
try:
import face_db
face_db.init_db()
except Exception as _exc:
logger.warning("face_db.init_db() failed (non-fatal): %s", _exc)
asyncio.create_task(_status_ticker())
asyncio.create_task(_server_state_poller())
logger.info("Web background tasks started")
@@ -224,12 +231,12 @@ def create_app() -> FastAPI:
# ---------------------------------------------------------------------------
async def _status_ticker() -> None:
"""Broadcast status_snapshot to all clients every 5 seconds."""
"""Broadcast status_snapshot to all clients every 2 seconds."""
from web.deps import get_bot, get_comfy, get_config
bus = get_bus()
while True:
await asyncio.sleep(5)
await asyncio.sleep(2)
try:
bot = get_bot()
comfy = get_comfy()

View File

@@ -37,12 +37,17 @@ _COOKIE_NAME = "ttb_session"
def _get_secret() -> str:
from web.deps import get_config
cfg = get_config()
if cfg and cfg.web_secret_key:
return cfg.web_secret_key
raise RuntimeError(
"WEB_SECRET_KEY must be set in the environment — "
"refusing to run with an insecure default."
)
key = cfg.web_secret_key if cfg else ""
if not key:
raise RuntimeError(
"WEB_SECRET_KEY must be set in the environment — "
"refusing to run with an insecure default."
)
if len(key) < 32:
raise RuntimeError(
"WEB_SECRET_KEY is too short (got %d chars, need ≥ 32)." % len(key)
)
return key
def create_jwt(label: str, *, admin: bool = False, expire_hours: int = 8) -> str:
@@ -102,6 +107,13 @@ def require_auth(ttb_session: Optional[str] = Cookie(default=None)) -> dict:
return payload
def optional_auth(ttb_session: Optional[str] = Cookie(default=None)) -> Optional[dict]:
"""Returns decoded JWT payload or None. Never raises 401."""
if not ttb_session:
return None
return decode_jwt(ttb_session)
def require_admin(user: dict = Depends(require_auth)) -> dict:
"""
FastAPI dependency that requires an admin JWT.

451
web/routers/faces_router.py Normal file
View File

@@ -0,0 +1,451 @@
"""GET/POST /api/faces/* — face recognition endpoints."""
from __future__ import annotations
import logging
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import Response
from pydantic import BaseModel
from face_service import _SIMILARITY_THRESHOLD
from web.auth import require_admin
router = APIRouter()
logger = logging.getLogger(__name__)
def _auto_link_for_person(person_id: int, face_db_module, ungrouped: list[dict]) -> int:
"""Auto-link ungrouped unidentified detections similar to person_id. Returns count linked."""
ref_embeddings = face_db_module.get_person_embeddings(person_id)
if not ref_embeddings or not ungrouped:
return 0
ref_M = np.stack(ref_embeddings)
ref_M_norm = ref_M / (np.linalg.norm(ref_M, axis=1, keepdims=True) + 1e-8)
count = 0
for ue in ungrouped:
norm_ue = ue["embedding"] / (np.linalg.norm(ue["embedding"]) + 1e-8)
if float(np.max(ref_M_norm @ norm_ue)) >= _SIMILARITY_THRESHOLD:
face_db_module.link_detection_to_person(ue["id"], person_id)
count += 1
return count
@router.get("/persons")
async def list_persons(_: dict = Depends(require_admin)):
"""List all known persons."""
import face_db
return {"persons": face_db.list_persons()}
@router.get("/persons/check")
async def check_person_name(name: str, _: dict = Depends(require_admin)):
"""Check whether a person name is already taken (case-insensitive)."""
import face_db
return {"exists": face_db.person_name_exists(name)}
@router.get("/persons/search")
async def search_persons(
q: str = "",
limit: int = Query(10, ge=1, le=50),
_: dict = Depends(require_admin),
):
"""Search persons by name/alias substring."""
import face_db
all_persons = face_db.list_persons()
q_lower = q.strip().lower()
if q_lower:
filtered = [
p for p in all_persons
if q_lower in p["name"].lower()
or any(q_lower in a["alias"].lower() for a in p["aliases"])
]
else:
filtered = all_persons
return {"persons": filtered[:limit]}
@router.get("/persons/{person_id}")
async def get_person(person_id: int, _: dict = Depends(require_admin)):
"""Get a single person with aliases."""
import face_db
person = face_db.get_person(person_id)
if person is None:
raise HTTPException(404, f"Person {person_id} not found")
return person
@router.get("/crop/{detection_id}")
async def get_face_crop(detection_id: int, _: dict = Depends(require_admin)):
"""Return a JPEG face crop for the given detection id."""
import asyncio
from face_service import get_face_service
svc = get_face_service()
if not svc.available:
raise HTTPException(503, "Face service not available")
loop = asyncio.get_event_loop()
crop = await loop.run_in_executor(svc._executor, svc.get_face_crop, detection_id)
if crop is None:
raise HTTPException(404, "Face crop not found")
return Response(content=crop, media_type="image/jpeg")
class _IdentifyItem(BaseModel):
detection_id: int
name: str
use_existing: bool = False
class _IdentifyRequest(BaseModel):
identifications: list[_IdentifyItem]
@router.get("/detections/unidentified")
async def list_unidentified_detections(
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0),
_: dict = Depends(require_admin),
):
"""List unidentified face detections from input images (paginated)."""
import face_db
detections, total = face_db.get_unidentified_input_detections(limit=limit, offset=offset)
return {"detections": detections, "total": total}
class _AliasRequest(BaseModel):
alias: str
@router.post("/persons/{person_id}/aliases")
async def add_person_alias(
person_id: int,
body: _AliasRequest,
_: dict = Depends(require_admin),
):
"""Add an alias to a person."""
import face_db
alias = body.alias.strip()
if not alias:
raise HTTPException(400, "Alias cannot be empty")
if len(alias) > 100:
raise HTTPException(400, "Alias too long (max 100 chars)")
try:
alias_id, _ = face_db.add_alias(person_id, alias)
except ValueError as e:
raise HTTPException(409, str(e)) from e
return {"id": alias_id, "alias": alias}
@router.delete("/persons/{person_id}/aliases/{alias_id}")
async def remove_person_alias(
person_id: int,
alias_id: int,
_: dict = Depends(require_admin),
):
"""Remove an alias from a person."""
import face_db
face_db.remove_alias(alias_id)
return {"ok": True}
class _RenameRequest(BaseModel):
name: str
@router.patch("/persons/{person_id}")
async def rename_person(
person_id: int,
body: _RenameRequest,
_: dict = Depends(require_admin),
):
"""Rename a person."""
import face_db
name = body.name.strip()
if not name:
raise HTTPException(400, "Name cannot be empty")
if len(name) > 100:
raise HTTPException(400, "Name too long (max 100 chars)")
if face_db.get_person(person_id) is None:
raise HTTPException(404, f"Person {person_id} not found")
try:
face_db.rename_person(person_id, name)
except ValueError as e:
raise HTTPException(409, str(e)) from e
return {"ok": True, "person_id": person_id, "name": name}
@router.delete("/persons/{person_id}")
async def delete_person(person_id: int, _: dict = Depends(require_admin)):
"""Delete a person and unidentify all their detections."""
import face_db
if face_db.get_person(person_id) is None:
raise HTTPException(404, f"Person {person_id} not found")
face_db.delete_person(person_id)
return {"ok": True}
@router.get("/persons/{person_id}/detections")
async def get_person_detections(
person_id: int,
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0),
_: dict = Depends(require_admin),
):
"""List detections for a person (paginated)."""
import face_db
detections, total = face_db.get_detections_for_person(person_id, limit=limit, offset=offset)
return {"detections": detections, "total": total}
class _MergePersonRequest(BaseModel):
other_person_id: int
@router.post("/persons/{person_id}/merge")
async def merge_persons(
person_id: int,
body: _MergePersonRequest,
_: dict = Depends(require_admin),
):
"""Merge another person into this one (survivor keeps their id)."""
import face_db
if person_id == body.other_person_id:
raise HTTPException(400, "Cannot merge person into themselves")
if face_db.get_person(person_id) is None:
raise HTTPException(404, f"Person {person_id} not found")
if face_db.get_person(body.other_person_id) is None:
raise HTTPException(404, f"Person {body.other_person_id} not found")
face_db.merge_persons(person_id, body.other_person_id)
return {"ok": True, "survivor_id": person_id, "absorbed_id": body.other_person_id}
class _ReassignRequest(BaseModel):
person_name: str | None = None
use_existing: bool = False
@router.post("/detections/{detection_id}/reassign")
async def reassign_detection(
detection_id: int,
body: _ReassignRequest,
_: dict = Depends(require_admin),
):
"""
Reassign or unidentify a detection.
- person_name=null → unidentify (set person_id=NULL)
- person_name=str → link to that person (create if needed, or use existing)
"""
import face_db
det = face_db.get_detection(detection_id)
if det is None:
raise HTTPException(404, f"Detection {detection_id} not found")
if body.person_name is None:
face_db.unidentify_detection(detection_id)
return {"detection_id": detection_id, "person_id": None, "unidentified": True}
name = body.person_name.strip()
if not name:
raise HTTPException(400, "Name cannot be empty")
if len(name) > 100:
raise HTTPException(400, "Name too long (max 100 chars)")
exists = face_db.person_name_exists(name)
if exists and not body.use_existing:
raise HTTPException(
409, f"A person named '{name}' already exists. Set use_existing=true to link."
)
person_id, is_new = face_db.get_or_create_person(name)
face_db.link_detection_to_person(detection_id, person_id)
return {"detection_id": detection_id, "person_id": person_id, "is_new": is_new}
class _ClusterRequest(BaseModel):
threshold: float = 0.45
class _MergeGroupsRequest(BaseModel):
keep_group_id: int
discard_group_id: int
class _IdentifyGroupRequest(BaseModel):
name: str
use_existing: bool = False
@router.get("/groups")
async def list_face_groups(_: dict = Depends(require_admin)):
"""List face groups with ≥ 2 unidentified detections."""
import face_db
groups = face_db.get_groups_with_detections()
return {"groups": groups, "total": len(groups)}
@router.get("/groups/{group_id}/detections")
async def get_face_group_detections(group_id: int, _: dict = Depends(require_admin)):
"""Return the unidentified detections belonging to a group (fetched on expand)."""
import face_db
detections = face_db.get_group_detections(group_id)
return {"detections": detections}
@router.post("/groups/compute")
async def compute_face_groups(body: _ClusterRequest, _: dict = Depends(require_admin)):
"""Run full re-cluster of all unidentified faces."""
from face_service import get_face_service
if not 0.3 <= body.threshold <= 0.7:
raise HTTPException(422, "threshold must be between 0.3 and 0.7")
svc = get_face_service()
if not svc.available:
raise HTTPException(503, "Face service not available")
groups = await svc.cluster_unidentified_faces(body.threshold)
total_detections = sum(len(g) for g in groups)
return {
"groups_created": len(groups),
"total_detections_clustered": total_detections,
"threshold": body.threshold,
}
@router.post("/groups/merge")
async def merge_face_groups(body: _MergeGroupsRequest, _: dict = Depends(require_admin)):
"""Merge two groups into one."""
import face_db
if body.keep_group_id == body.discard_group_id:
raise HTTPException(400, "keep_group_id and discard_group_id must differ")
groups_index = {g["id"] for g in face_db.get_groups_with_detections()}
if body.keep_group_id not in groups_index:
raise HTTPException(404, f"Group {body.keep_group_id} not found")
if body.discard_group_id not in groups_index:
raise HTTPException(404, f"Group {body.discard_group_id} not found")
face_db.merge_groups(body.keep_group_id, body.discard_group_id)
return {"ok": True, "surviving_group_id": body.keep_group_id}
@router.post("/groups/{group_id}/identify")
async def identify_face_group(
group_id: int,
body: _IdentifyGroupRequest,
_: dict = Depends(require_admin),
):
"""
Identify all detections in a group as one person.
After identification, auto-links any similar ungrouped unidentified detections.
Works for both input and output source types.
"""
import face_db
name = body.name.strip()
if not name:
raise HTTPException(400, "Name cannot be empty")
if len(name) > 100:
raise HTTPException(400, "Name too long (max 100 chars)")
detections = face_db.get_group_detections(group_id)
if not detections:
raise HTTPException(404, f"Group {group_id} not found or has no unidentified detections")
exists = face_db.person_name_exists(name)
if exists and not body.use_existing:
raise HTTPException(
409, f"A person named '{name}' already exists. Set use_existing=true to link."
)
person_id, is_new = face_db.get_or_create_person(name)
identified_count = 0
for det in detections:
face_db.link_detection_to_person(det["id"], person_id)
identified_count += 1
face_db.delete_group(group_id)
# Post-identify: auto-link ungrouped unidentified detections similar to this person
ungrouped = face_db.get_ungrouped_unidentified_embeddings()
auto_linked_count = _auto_link_for_person(person_id, face_db, ungrouped)
return {
"person_id": person_id,
"person_name": name,
"is_new": is_new,
"identified_count": identified_count,
"auto_linked_count": auto_linked_count,
}
@router.delete("/groups/{group_id}/detections/{detection_id}")
async def remove_group_detection(
group_id: int,
detection_id: int,
_: dict = Depends(require_admin),
):
"""Remove a single detection from its group. Cleans up singleton groups."""
import face_db
det = face_db.get_detection(detection_id)
if det is None or det.get("group_id") != group_id:
raise HTTPException(404, "Detection not found in this group")
face_db.remove_detection_from_group(detection_id)
face_db.delete_singleton_groups()
return {"ok": True}
@router.post("/rescan/outputs")
async def rescan_output_embeddings(_: dict = Depends(require_admin)):
"""Re-detect faces in stored output images to rebuild NULL embeddings."""
import face_db
from face_service import get_face_service
svc = get_face_service()
if not svc.available:
raise HTTPException(503, "Face service not available")
source_ids = face_db.get_null_embedding_output_source_ids()
total_updated = 0
for source_id in source_ids:
updated = await svc.rescan_output_embedding(source_id)
total_updated += updated
return {"processed": len(source_ids), "updated": total_updated}
@router.post("/identify")
async def identify_faces(body: _IdentifyRequest, _: dict = Depends(require_admin)):
"""
Identify one or more face detections by name.
- If the name is new → creates a new person and links the detection.
- If the name exists and ``use_existing=true`` → links to the existing person.
- If the name exists and ``use_existing=false`` → HTTP 409.
- Only detections with ``source_type='input'`` may be identified via the web UI.
"""
import face_db
results = []
for item in body.identifications:
name = item.name.strip()
if not name:
raise HTTPException(400, "Name cannot be empty")
if len(name) > 100:
raise HTTPException(400, "Name too long (max 100 chars)")
det = face_db.get_detection(item.detection_id)
if det is None:
raise HTTPException(404, f"Detection {item.detection_id} not found")
if det["source_type"] != "input":
raise HTTPException(403, "Only input-image detections may be identified via web UI")
exists = face_db.person_name_exists(name)
if exists and not item.use_existing:
raise HTTPException(
409,
f"A person named '{name}' already exists. Set use_existing=true to link.",
)
person_id, is_new = face_db.get_or_create_person(name)
face_db.link_detection_to_person(item.detection_id, person_id)
results.append({
"detection_id": item.detection_id,
"person_id": person_id,
"person_name": name,
"is_new": is_new,
})
# Auto-link similar ungrouped faces for each person identified in this batch
ungrouped = face_db.get_ungrouped_unidentified_embeddings()
unique_pids = {r["person_id"] for r in results}
auto_linked_count = sum(_auto_link_for_person(pid, face_db, ungrouped) for pid in unique_pids)
return {"identifications": results, "auto_linked_count": auto_linked_count}

View File

@@ -17,6 +17,32 @@ router = APIRouter()
logger = logging.getLogger(__name__)
def _materialize_image_slots(
overrides: dict, comfy_input_path: str
) -> tuple[dict, list[str]]:
"""
For each override whose value is an existing ttb_* file, copy it to a
unique name so concurrent jobs each have an immutable copy on disk.
Returns (updated_overrides, paths_to_delete_after_generation).
"""
import shutil
import uuid as _uuid
if not comfy_input_path:
return overrides, []
updated = dict(overrides)
cleanup: list[str] = []
input_dir = Path(comfy_input_path)
for key, val in overrides.items():
if isinstance(val, str) and val.startswith("ttb_") and "." in val:
src = input_dir / val
if src.is_file():
unique_name = f"{src.stem}_{_uuid.uuid4().hex[:8]}{src.suffix}"
shutil.copy2(src, input_dir / unique_name)
updated[key] = unique_name
cleanup.append(str(input_dir / unique_name))
return updated, cleanup
class GenerateRequest(BaseModel):
prompt: str
negative_prompt: Optional[str] = None
@@ -105,7 +131,8 @@ async def generate(body: GenerateRequest, user: dict = Depends(require_auth)):
from generation_db import record_generation, record_file
gen_id = record_generation(pid, "web", user_label, overrides_for_gen, seed_used)
for i, img_data in enumerate(images):
record_file(gen_id, f"image_{i:04d}.png", img_data)
file_id = record_file(gen_id, f"image_{i:04d}.png", img_data)
comfy._schedule_face_scan("image", file_id, img_data)
if config and videos:
for vid in videos:
vsub = vid.get("video_subfolder", "")
@@ -116,7 +143,9 @@ async def generate(body: GenerateRequest, user: dict = Depends(require_auth)):
else Path(config.comfy_output_path) / vname
)
try:
record_file(gen_id, vname, vpath.read_bytes())
vid_data = vpath.read_bytes()
file_id = record_file(gen_id, vname, vid_data)
comfy._schedule_face_scan("video", file_id, vid_data)
except OSError:
pass
except Exception as exc:
@@ -163,25 +192,32 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
registry = get_user_registry()
count = max(1, min(body.count, 20)) # cap at 20
async def _run_one():
# Use the user's own state and template
if registry:
user_sm = registry.get_state_manager(user_label)
user_template = registry.get_workflow_template(user_label)
else:
user_sm = comfy.state_manager
user_template = comfy.workflow_manager.get_workflow_template()
# --- snapshot state at queue time, not at execution time ---
if registry:
_user_sm = registry.get_state_manager(user_label)
_user_template = registry.get_workflow_template(user_label)
else:
_user_sm = comfy.state_manager
_user_template = comfy.workflow_manager.get_workflow_template()
if not user_template:
base_overrides = _user_sm.get_overrides()
if body.overrides:
base_overrides = {**base_overrides, **body.overrides}
_config = get_config()
async def _run_one(overrides: dict, cleanup_paths: list[str]):
if not _user_template:
await bus.broadcast_to_user(user_label, "generation_error", {
"prompt_id": None, "error": "No workflow template loaded"
})
for p in cleanup_paths:
try:
Path(p).unlink(missing_ok=True)
except Exception:
pass
return
overrides = user_sm.get_overrides()
if body.overrides:
overrides = {**overrides, **body.overrides}
import uuid
pid = str(uuid.uuid4())
@@ -190,7 +226,7 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
"node": node, "prompt_id": pid_
}))
workflow, applied = comfy.inspector.inject_overrides(user_template, overrides)
workflow, applied = comfy.inspector.inject_overrides(_user_template, overrides)
seed_used = applied.get("seed")
comfy.last_seed = seed_used
@@ -201,17 +237,23 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
await bus.broadcast_to_user(user_label, "generation_error", {
"prompt_id": None, "error": str(exc)
})
for p in cleanup_paths:
try:
Path(p).unlink(missing_ok=True)
except Exception:
pass
return
comfy.last_prompt_id = pid
comfy.total_generated += 1
config = get_config()
config = _config
try:
from generation_db import record_generation, record_file
gen_id = record_generation(pid, "web", user_label, overrides, seed_used)
for i, img_data in enumerate(images):
record_file(gen_id, f"image_{i:04d}.png", img_data)
file_id = record_file(gen_id, f"image_{i:04d}.png", img_data)
comfy._schedule_face_scan("image", file_id, img_data)
if config and videos:
for vid in videos:
vsub = vid.get("video_subfolder", "")
@@ -222,7 +264,9 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
else Path(config.comfy_output_path) / vname
)
try:
record_file(gen_id, vname, vpath.read_bytes())
vid_data = vpath.read_bytes()
file_id = record_file(gen_id, vname, vid_data)
comfy._schedule_face_scan("video", file_id, vid_data)
except OSError:
pass
except Exception as exc:
@@ -236,6 +280,13 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
config.media_upload_pass,
))
# Clean up unique image copies now that ComfyUI has ingested them
for p in cleanup_paths:
try:
Path(p).unlink(missing_ok=True)
except Exception:
pass
await bus.broadcast("queue_update", {"prompt_id": pid, "status": "complete"})
await bus.broadcast_to_user(user_label, "generation_complete", {
"prompt_id": pid,
@@ -246,7 +297,10 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
depth = await comfy.get_queue_depth()
for _ in range(count):
asyncio.create_task(_run_one())
job_overrides, cleanup = _materialize_image_slots(
base_overrides, _config.comfy_input_path if _config else ""
)
asyncio.create_task(_run_one(job_overrides, cleanup))
return {
"queued": True,

View File

@@ -3,12 +3,14 @@ POST /api/history/{prompt_id}/share; DELETE /api/history/{prompt_id}/share"""
from __future__ import annotations
import base64
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request
from fastapi.responses import Response
from pydantic import BaseModel
from web.auth import require_auth
from web.auth import require_admin, require_auth
router = APIRouter()
@@ -31,22 +33,135 @@ def _assert_owns(prompt_id: str, user: dict) -> None:
async def get_history(
user: dict = Depends(require_auth),
q: Optional[str] = Query(None, description="Keyword to search in overrides JSON"),
persons: list[str] = Query(default=[], alias="persons", description="Filter by person name/alias substring (repeatable)"),
):
"""Return generation history. Admins see all; regular users see only their own.
Pass ?q=keyword to filter by prompt text or any override field."""
Pass ?q=keyword to filter by prompt text or any override field.
Pass ?persons=name (repeatable) to filter by persons detected in output images."""
import face_db as face_db_mod
from generation_db import (
get_history as db_get_history,
get_history_for_user,
search_history,
search_history_for_user,
get_generation_ids_for_file_ids,
get_file_ids_for_generation_ids,
)
# Person filter: OR union across all named persons, then filter rows
gen_id_filter: list[int] | None = None
all_file_ids: set[int] = set()
active_persons = [p.strip() for p in persons if p.strip()]
if active_persons:
for p in active_persons:
ids = face_db_mod.get_source_ids_for_person_query(p, "output")
all_file_ids.update(ids)
if not all_file_ids:
return {"history": []}
gen_ids = get_generation_ids_for_file_ids(list(all_file_ids))
if not gen_ids:
return {"history": []}
gen_id_filter = gen_ids
if q and q.strip():
if user.get("admin"):
return {"history": search_history(q.strip(), limit=50)}
return {"history": search_history_for_user(user["sub"], q.strip(), limit=50)}
if user.get("admin"):
return {"history": db_get_history(limit=50)}
return {"history": get_history_for_user(user["sub"], limit=50)}
rows = search_history(q.strip(), limit=50)
else:
rows = search_history_for_user(user["sub"], q.strip(), limit=50)
else:
if user.get("admin"):
rows = db_get_history(limit=50)
else:
rows = get_history_for_user(user["sub"], limit=50)
if gen_id_filter is not None:
gen_id_set = set(gen_id_filter)
rows = [r for r in rows if r["id"] in gen_id_set]
# Annotate each row with detected_persons when person filtering is active
if active_persons and rows:
row_ids = [r["id"] for r in rows]
file_ids_per_gen = get_file_ids_for_generation_ids(row_ids)
all_annotate_ids = [fid for fids in file_ids_per_gen.values() for fid in fids]
person_map = face_db_mod.get_persons_for_source_id_map(all_annotate_ids, "output") if all_annotate_ids else {}
for row in rows:
row["detected_persons"] = list({
name
for fid in file_ids_per_gen.get(row["id"], [])
for name in person_map.get(fid, [])
})
return {"history": rows}
@router.get("/{prompt_id}/persons")
async def get_generation_persons(prompt_id: str, user: dict = Depends(require_auth)):
"""Return persons detected in the output images of a generation."""
_assert_owns(prompt_id, user)
import generation_db as gen_db
import face_db as face_db_mod
file_ids = gen_db.get_file_ids_for_prompt(prompt_id)
persons = face_db_mod.get_persons_for_source_ids(file_ids, "output")
return {"persons": persons}
class _AddPersonRequest(BaseModel):
name: str
@router.post("/{prompt_id}/persons")
async def add_generation_person(
prompt_id: str,
body: _AddPersonRequest,
_: dict = Depends(require_admin),
):
"""Manually tag a person to a generation's output files."""
import generation_db as gen_db
import face_db as face_db_mod
file_ids = gen_db.get_file_ids_for_prompt(prompt_id)
if not file_ids:
raise HTTPException(404, f"Generation {prompt_id!r} not found")
name = body.name.strip()
if not name:
raise HTTPException(400, "Name cannot be empty")
if len(name) > 100:
raise HTTPException(400, "Name too long (max 100 chars)")
person_id, _ = face_db_mod.get_or_create_person(name)
# Only add if not already tagged on any file of this generation
existing = face_db_mod.get_persons_for_source_ids(file_ids, "output")
if not any(p["id"] == person_id for p in existing):
face_db_mod.insert_detection(
source_type="output",
source_id=file_ids[0],
embedding=None,
bbox={},
face_index=0,
person_id=person_id,
)
return {"person_id": person_id, "name": name}
@router.delete("/{prompt_id}/persons/{person_id}")
async def remove_generation_person(
prompt_id: str,
person_id: int,
_: dict = Depends(require_admin),
):
"""Remove all face detections linking a person to this generation's output files."""
import generation_db as gen_db
import face_db as face_db_mod
file_ids = gen_db.get_file_ids_for_prompt(prompt_id)
if not file_ids:
raise HTTPException(404, f"Generation {prompt_id!r} not found")
face_db_mod.remove_person_from_source_ids(file_ids, person_id, "output")
return {"ok": True}
@router.get("/{prompt_id}/images")
@@ -101,6 +216,11 @@ async def get_history_file(
start = int(start_str) if start_str else 0
end = int(end_str) if end_str else total - 1
end = min(end, total - 1)
if start < 0 or start > end:
return Response(
status_code=416,
headers={"Content-Range": f"bytes */{total}"},
)
chunk = data[start : end + 1]
return Response(
content=chunk,
@@ -110,6 +230,7 @@ async def get_history_file(
"Content-Range": f"bytes {start}-{end}/{total}",
"Accept-Ranges": "bytes",
"Content-Length": str(len(chunk)),
"Cache-Control": "private",
},
)
@@ -119,25 +240,62 @@ async def get_history_file(
headers={
"Accept-Ranges": "bytes",
"Content-Length": str(total),
"Cache-Control": "private",
},
)
class _ShareRequest(BaseModel):
is_public: bool = False
expires_in_hours: Optional[float] = None
max_views: Optional[int] = None
@router.post("/{prompt_id}/share")
async def create_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
async def create_generation_share(
prompt_id: str,
body: _ShareRequest = Body(default_factory=_ShareRequest),
user: dict = Depends(require_auth),
):
"""Create a share token for a generation. Only the owner may share."""
# Use the same 404-for-everything helper to avoid leaking prompt_id existence
_assert_owns(prompt_id, user)
from generation_db import create_share as db_create_share
token = db_create_share(prompt_id, user["sub"])
return {"share_token": token}
from generation_db import (
create_share as db_create_share,
get_active_share_for_prompt as db_get_active_share,
)
existing = db_get_active_share(prompt_id, user["sub"])
if existing:
raise HTTPException(409, "A share already exists — revoke it first")
if body.is_public and body.expires_in_hours is None and body.max_views is None:
raise HTTPException(400, "Public shares must have at least one expiry condition (time or view limit)")
if body.max_views is not None and body.max_views < 1:
raise HTTPException(400, "max_views must be >= 1")
if body.expires_in_hours is not None and body.expires_in_hours <= 0:
raise HTTPException(400, "expires_in_hours must be > 0")
expires_at = None
if body.expires_in_hours is not None:
expires_at = (datetime.now(timezone.utc) + timedelta(hours=body.expires_in_hours)).isoformat()
share = db_create_share(
prompt_id, user["sub"],
is_public=body.is_public,
expires_at=expires_at,
max_views=body.max_views,
)
return {
"share_token": share["share_token"],
"is_public": bool(share["is_public"]),
"expires_at": share["expires_at"],
"max_views": share["max_views"],
}
@router.delete("/{prompt_id}/share")
async def revoke_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
"""Revoke a share token for a generation. Only the owner may revoke."""
"""Revoke a share token for a generation. Admins can revoke any share."""
from generation_db import revoke_share as db_revoke_share
deleted = db_revoke_share(prompt_id, user["sub"])
# Admins pass None for owner_label to delete by prompt_id alone
owner_label = None if user.get("admin") else user["sub"]
deleted = db_revoke_share(prompt_id, owner_label)
if not deleted:
raise HTTPException(404, "No active share found for this generation")
return {"ok": True}

View File

@@ -3,10 +3,11 @@ from __future__ import annotations
import logging
import mimetypes
import re
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, UploadFile
from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import Response
from web.auth import require_auth
@@ -15,10 +16,38 @@ from web.deps import get_config, get_user_registry
router = APIRouter()
logger = logging.getLogger(__name__)
_ALLOWED_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
_MAX_UPLOAD_BYTES = 50 * 1024 * 1024 # 50 MB
_SAFE_SLOT_RE = re.compile(r'^[a-zA-Z0-9_\-]+$')
def _validate_slot_key(slot_key: str) -> None:
if not _SAFE_SLOT_RE.match(slot_key):
raise HTTPException(400, "slot_key may only contain letters, digits, hyphens, underscores")
@router.get("")
async def list_inputs(_: dict = Depends(require_auth)):
"""List all input images (Discord + web uploads)."""
async def list_inputs(
_: dict = Depends(require_auth),
persons: list[str] = Query(default=[], alias="persons", description="Filter by person name/alias substring (repeatable)"),
):
"""List all input images (Discord + web uploads). Optionally filter by persons."""
active_persons = [p.strip() for p in persons if p.strip()]
if active_persons:
import face_db as face_db_mod
from input_image_db import get_images_by_ids
all_ids: set[int] = set()
for p in active_persons:
ids = face_db_mod.get_source_ids_for_person_query(p, "input")
all_ids.update(ids)
images = list(get_images_by_ids(list(all_ids))) if all_ids else []
if images:
person_map = face_db_mod.get_persons_for_source_id_map(
[img["id"] for img in images], "input"
)
for img in images:
img["detected_persons"] = person_map.get(img["id"], [])
return images
from input_image_db import get_all_images
rows = get_all_images()
return [dict(r) for r in rows]
@@ -44,8 +73,16 @@ async def upload_input(
if config is None:
raise HTTPException(503, "Config not available")
if slot_key:
_validate_slot_key(slot_key)
data = await file.read()
filename = file.filename or "upload.png"
ext = Path(filename).suffix.lower()
if ext not in _ALLOWED_EXTS:
raise HTTPException(415, f"Unsupported file type '{ext}'. Allowed: {sorted(_ALLOWED_EXTS)}")
if len(data) > _MAX_UPLOAD_BYTES:
raise HTTPException(413, "File too large (max 50 MB)")
from input_image_db import upsert_image, activate_image_for_slot
row_id = upsert_image(
@@ -72,7 +109,30 @@ async def upload_input(
if comfy:
comfy.state_manager.set_override(slot_key, activated_filename)
return {"id": row_id, "filename": filename, "slot_key": slot_key, "activated_filename": activated_filename}
# Face scan — runs synchronously here (~1-2 s); unknown faces returned to UI
pending_faces: list[dict] = []
try:
from face_service import get_face_service
import face_db as _face_db
_face_db.init_db()
svc = get_face_service()
if svc.available:
results = await svc.scan_input_image(row_id, data)
pending_faces = [
{"detection_id": r.detection_id, "face_index": r.face_index, "bbox": r.bbox}
for r in results
if r.matched_person_id is None
]
except Exception as exc:
logger.warning("Face scan failed for upload row_id=%d: %s", row_id, exc)
return {
"id": row_id,
"filename": filename,
"slot_key": slot_key,
"activated_filename": activated_filename,
"pending_faces": pending_faces,
}
@router.post("/{row_id}/activate")
@@ -92,6 +152,7 @@ async def activate_input(
raise HTTPException(404, "Image not found")
user_label: str = user["sub"]
_validate_slot_key(slot_key)
namespaced_key = f"{user_label}_{slot_key}"
try:

View File

@@ -2,28 +2,40 @@
from __future__ import annotations
import base64
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
from web.auth import require_auth
from web.auth import optional_auth
router = APIRouter()
def _auth_gate(gen: dict, user: Optional[dict]) -> None:
"""Raise 401 if the share is private and the user is not authenticated."""
if not gen.get("is_public") and user is None:
raise HTTPException(401, "Authentication required")
@router.get("/{token}")
async def get_share(token: str, _: dict = Depends(require_auth)):
"""Fetch share metadata and images. Any authenticated user may view a valid share link."""
async def get_share(token: str, user: Optional[dict] = Depends(optional_auth)):
"""Fetch share metadata and images. Public shares require no login; private shares require auth."""
from generation_db import get_share_by_token, get_files
gen = get_share_by_token(token)
if gen is None:
raise HTTPException(404, "Share not found or revoked")
raise HTTPException(404, "Share not found, expired, or revoked")
_auth_gate(gen, user)
files = get_files(gen["prompt_id"])
return {
"prompt_id": gen["prompt_id"],
"created_at": gen["created_at"],
"overrides": gen["overrides"],
"seed": gen["seed"],
"is_public": bool(gen["is_public"]),
"expires_at": gen["expires_at"],
"max_views": gen["max_views"],
"view_count": gen["view_count"],
"images": [
{
"filename": f["filename"],
@@ -40,13 +52,14 @@ async def get_share_file(
token: str,
filename: str,
request: Request,
_: dict = Depends(require_auth),
user: Optional[dict] = Depends(optional_auth),
):
"""Stream a single output file via share token, with HTTP range support for video seeking."""
from generation_db import get_share_by_token, get_files
gen = get_share_by_token(token)
from generation_db import get_share_meta, get_files
gen = get_share_meta(token)
if gen is None:
raise HTTPException(404, "Share not found or revoked")
raise HTTPException(404, "Share not found, expired, or revoked")
_auth_gate(gen, user)
files = get_files(gen["prompt_id"])
matched = next((f for f in files if f["filename"] == filename), None)
if matched is None:
@@ -63,6 +76,11 @@ async def get_share_file(
start = int(start_str) if start_str else 0
end = int(end_str) if end_str else total - 1
end = min(end, total - 1)
if start < 0 or start > end:
return Response(
status_code=416,
headers={"Content-Range": f"bytes */{total}"},
)
chunk = data[start : end + 1]
return Response(
content=chunk,
@@ -72,6 +90,7 @@ async def get_share_file(
"Content-Range": f"bytes {start}-{end}/{total}",
"Accept-Ranges": "bytes",
"Content-Length": str(len(chunk)),
"Cache-Control": "public, max-age=3600" if gen.get("is_public") else "private",
},
)
@@ -81,5 +100,6 @@ async def get_share_file(
headers={
"Accept-Ranges": "bytes",
"Content-Length": str(total),
"Cache-Control": "public, max-age=3600" if gen.get("is_public") else "private",
},
)