manual submit

web/app.py (11 changed lines)
@@ -178,6 +178,7 @@ def create_app() -> FastAPI:
     from web.routers.share_router import router as share_router
     from web.routers.workflow_router import router as workflow_router
     from web.routers.ws_router import router as ws_router
+    from web.routers.faces_router import router as faces_router

     app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
     app.include_router(admin_router, prefix="/api/admin", tags=["admin"])
@@ -191,6 +192,7 @@ def create_app() -> FastAPI:
     app.include_router(share_router, prefix="/api/share", tags=["share"])
     app.include_router(workflow_router, prefix="/api/workflow", tags=["workflow"])
     app.include_router(ws_router, tags=["ws"])
+    app.include_router(faces_router, prefix="/api/faces", tags=["faces"])

     from web.routers.ws_router import websocket_endpoint as _ws_endpoint

@@ -212,6 +214,11 @@ def create_app() -> FastAPI:

     @app.on_event("startup")
     async def _startup():
+        try:
+            import face_db
+            face_db.init_db()
+        except Exception as _exc:
+            logger.warning("face_db.init_db() failed (non-fatal): %s", _exc)
         asyncio.create_task(_status_ticker())
         asyncio.create_task(_server_state_poller())
         logger.info("Web background tasks started")
@@ -224,12 +231,12 @@ def create_app() -> FastAPI:

 # ---------------------------------------------------------------------------

 async def _status_ticker() -> None:
-    """Broadcast status_snapshot to all clients every 5 seconds."""
+    """Broadcast status_snapshot to all clients every 2 seconds."""
     from web.deps import get_bot, get_comfy, get_config
     bus = get_bus()

     while True:
-        await asyncio.sleep(5)
+        await asyncio.sleep(2)
         try:
             bot = get_bot()
             comfy = get_comfy()
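The startup hook above kicks off the two pollers with bare asyncio.create_task calls. As a general pattern sketch (not code from this repo), a periodic broadcast loop in the style of _status_ticker looks like the following; holding a reference to the task matters because asyncio only keeps weak references to running tasks:

import asyncio

_background_tasks: set[asyncio.Task] = set()

async def _tick() -> None:
    # Sleep first, then do the work, and never let one bad
    # iteration kill the loop (mirrors _status_ticker above).
    while True:
        await asyncio.sleep(2)
        try:
            ...  # gather and broadcast a status snapshot
        except Exception:
            pass

def start_background_tasks() -> None:
    task = asyncio.create_task(_tick())
    _background_tasks.add(task)                        # keep a strong reference
    task.add_done_callback(_background_tasks.discard)  # drop it once the task ends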
web/auth.py (24 changed lines)
@@ -37,12 +37,17 @@ _COOKIE_NAME = "ttb_session"
 def _get_secret() -> str:
     from web.deps import get_config
     cfg = get_config()
-    if cfg and cfg.web_secret_key:
-        return cfg.web_secret_key
-    raise RuntimeError(
-        "WEB_SECRET_KEY must be set in the environment — "
-        "refusing to run with an insecure default."
-    )
+    key = cfg.web_secret_key if cfg else ""
+    if not key:
+        raise RuntimeError(
+            "WEB_SECRET_KEY must be set in the environment — "
+            "refusing to run with an insecure default."
+        )
+    if len(key) < 32:
+        raise RuntimeError(
+            "WEB_SECRET_KEY is too short (got %d chars, need ≥ 32)." % len(key)
+        )
+    return key


 def create_jwt(label: str, *, admin: bool = False, expire_hours: int = 8) -> str:
@@ -102,6 +107,13 @@ def require_auth(ttb_session: Optional[str] = Cookie(default=None)) -> dict:
     return payload


+def optional_auth(ttb_session: Optional[str] = Cookie(default=None)) -> Optional[dict]:
+    """Returns decoded JWT payload or None. Never raises 401."""
+    if not ttb_session:
+        return None
+    return decode_jwt(ttb_session)
+
+
 def require_admin(user: dict = Depends(require_auth)) -> dict:
     """
     FastAPI dependency that requires an admin JWT.
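With the insecure fallback removed, _get_secret now hard-fails on a missing or short key. A one-line sketch for generating a value that passes the new 32-character minimum (the env-var line is just illustrative output):

import secrets

# 32 random bytes render as 64 hex characters, well above the 32-char minimum.
print("WEB_SECRET_KEY=" + secrets.token_hex(32))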
web/routers/faces_router.py (new file, 451 lines)
@@ -0,0 +1,451 @@
+"""GET/POST /api/faces/* — face recognition endpoints."""
+from __future__ import annotations
+
+import logging
+
+import numpy as np
+from fastapi import APIRouter, Depends, HTTPException, Query
+from fastapi.responses import Response
+from pydantic import BaseModel
+
+from face_service import _SIMILARITY_THRESHOLD
+from web.auth import require_admin
+
+router = APIRouter()
+logger = logging.getLogger(__name__)
+
+
+def _auto_link_for_person(person_id: int, face_db_module, ungrouped: list[dict]) -> int:
+    """Auto-link ungrouped unidentified detections similar to person_id. Returns count linked."""
+    ref_embeddings = face_db_module.get_person_embeddings(person_id)
+    if not ref_embeddings or not ungrouped:
+        return 0
+    ref_M = np.stack(ref_embeddings)
+    ref_M_norm = ref_M / (np.linalg.norm(ref_M, axis=1, keepdims=True) + 1e-8)
+    count = 0
+    for ue in ungrouped:
+        norm_ue = ue["embedding"] / (np.linalg.norm(ue["embedding"]) + 1e-8)
+        if float(np.max(ref_M_norm @ norm_ue)) >= _SIMILARITY_THRESHOLD:
+            face_db_module.link_detection_to_person(ue["id"], person_id)
+            count += 1
+    return count
+
+
+@router.get("/persons")
+async def list_persons(_: dict = Depends(require_admin)):
+    """List all known persons."""
+    import face_db
+    return {"persons": face_db.list_persons()}
+
+
+@router.get("/persons/check")
+async def check_person_name(name: str, _: dict = Depends(require_admin)):
+    """Check whether a person name is already taken (case-insensitive)."""
+    import face_db
+    return {"exists": face_db.person_name_exists(name)}
+
+
+@router.get("/persons/search")
+async def search_persons(
+    q: str = "",
+    limit: int = Query(10, ge=1, le=50),
+    _: dict = Depends(require_admin),
+):
+    """Search persons by name/alias substring."""
+    import face_db
+    all_persons = face_db.list_persons()
+    q_lower = q.strip().lower()
+    if q_lower:
+        filtered = [
+            p for p in all_persons
+            if q_lower in p["name"].lower()
+            or any(q_lower in a["alias"].lower() for a in p["aliases"])
+        ]
+    else:
+        filtered = all_persons
+    return {"persons": filtered[:limit]}
+
+
+@router.get("/persons/{person_id}")
+async def get_person(person_id: int, _: dict = Depends(require_admin)):
+    """Get a single person with aliases."""
+    import face_db
+    person = face_db.get_person(person_id)
+    if person is None:
+        raise HTTPException(404, f"Person {person_id} not found")
+    return person
+
+
+@router.get("/crop/{detection_id}")
+async def get_face_crop(detection_id: int, _: dict = Depends(require_admin)):
+    """Return a JPEG face crop for the given detection id."""
+    import asyncio
+    from face_service import get_face_service
+    svc = get_face_service()
+    if not svc.available:
+        raise HTTPException(503, "Face service not available")
+    loop = asyncio.get_event_loop()
+    crop = await loop.run_in_executor(svc._executor, svc.get_face_crop, detection_id)
+    if crop is None:
+        raise HTTPException(404, "Face crop not found")
+    return Response(content=crop, media_type="image/jpeg")
+
+
+class _IdentifyItem(BaseModel):
+    detection_id: int
+    name: str
+    use_existing: bool = False
+
+
+class _IdentifyRequest(BaseModel):
+    identifications: list[_IdentifyItem]
+
+
+@router.get("/detections/unidentified")
+async def list_unidentified_detections(
+    limit: int = Query(50, ge=1, le=200),
+    offset: int = Query(0, ge=0),
+    _: dict = Depends(require_admin),
+):
+    """List unidentified face detections from input images (paginated)."""
+    import face_db
+    detections, total = face_db.get_unidentified_input_detections(limit=limit, offset=offset)
+    return {"detections": detections, "total": total}
+
+
+class _AliasRequest(BaseModel):
+    alias: str
+
+
+@router.post("/persons/{person_id}/aliases")
+async def add_person_alias(
+    person_id: int,
+    body: _AliasRequest,
+    _: dict = Depends(require_admin),
+):
+    """Add an alias to a person."""
+    import face_db
+    alias = body.alias.strip()
+    if not alias:
+        raise HTTPException(400, "Alias cannot be empty")
+    if len(alias) > 100:
+        raise HTTPException(400, "Alias too long (max 100 chars)")
+    try:
+        alias_id, _ = face_db.add_alias(person_id, alias)
+    except ValueError as e:
+        raise HTTPException(409, str(e)) from e
+    return {"id": alias_id, "alias": alias}
+
+
+@router.delete("/persons/{person_id}/aliases/{alias_id}")
+async def remove_person_alias(
+    person_id: int,
+    alias_id: int,
+    _: dict = Depends(require_admin),
+):
+    """Remove an alias from a person."""
+    import face_db
+    face_db.remove_alias(alias_id)
+    return {"ok": True}
+
+
+class _RenameRequest(BaseModel):
+    name: str
+
+
+@router.patch("/persons/{person_id}")
+async def rename_person(
+    person_id: int,
+    body: _RenameRequest,
+    _: dict = Depends(require_admin),
+):
+    """Rename a person."""
+    import face_db
+    name = body.name.strip()
+    if not name:
+        raise HTTPException(400, "Name cannot be empty")
+    if len(name) > 100:
+        raise HTTPException(400, "Name too long (max 100 chars)")
+    if face_db.get_person(person_id) is None:
+        raise HTTPException(404, f"Person {person_id} not found")
+    try:
+        face_db.rename_person(person_id, name)
+    except ValueError as e:
+        raise HTTPException(409, str(e)) from e
+    return {"ok": True, "person_id": person_id, "name": name}
+
+
+@router.delete("/persons/{person_id}")
+async def delete_person(person_id: int, _: dict = Depends(require_admin)):
+    """Delete a person and unidentify all their detections."""
+    import face_db
+    if face_db.get_person(person_id) is None:
+        raise HTTPException(404, f"Person {person_id} not found")
+    face_db.delete_person(person_id)
+    return {"ok": True}
+
+
+@router.get("/persons/{person_id}/detections")
+async def get_person_detections(
+    person_id: int,
+    limit: int = Query(50, ge=1, le=200),
+    offset: int = Query(0, ge=0),
+    _: dict = Depends(require_admin),
+):
+    """List detections for a person (paginated)."""
+    import face_db
+    detections, total = face_db.get_detections_for_person(person_id, limit=limit, offset=offset)
+    return {"detections": detections, "total": total}
+
+
+class _MergePersonRequest(BaseModel):
+    other_person_id: int
+
+
+@router.post("/persons/{person_id}/merge")
+async def merge_persons(
+    person_id: int,
+    body: _MergePersonRequest,
+    _: dict = Depends(require_admin),
+):
+    """Merge another person into this one (survivor keeps their id)."""
+    import face_db
+    if person_id == body.other_person_id:
+        raise HTTPException(400, "Cannot merge person into themselves")
+    if face_db.get_person(person_id) is None:
+        raise HTTPException(404, f"Person {person_id} not found")
+    if face_db.get_person(body.other_person_id) is None:
+        raise HTTPException(404, f"Person {body.other_person_id} not found")
+    face_db.merge_persons(person_id, body.other_person_id)
+    return {"ok": True, "survivor_id": person_id, "absorbed_id": body.other_person_id}
+
+
+class _ReassignRequest(BaseModel):
+    person_name: str | None = None
+    use_existing: bool = False
+
+
+@router.post("/detections/{detection_id}/reassign")
+async def reassign_detection(
+    detection_id: int,
+    body: _ReassignRequest,
+    _: dict = Depends(require_admin),
+):
+    """
+    Reassign or unidentify a detection.
+    - person_name=null → unidentify (set person_id=NULL)
+    - person_name=str → link to that person (create if needed, or use existing)
+    """
+    import face_db
+    det = face_db.get_detection(detection_id)
+    if det is None:
+        raise HTTPException(404, f"Detection {detection_id} not found")
+    if body.person_name is None:
+        face_db.unidentify_detection(detection_id)
+        return {"detection_id": detection_id, "person_id": None, "unidentified": True}
+    name = body.person_name.strip()
+    if not name:
+        raise HTTPException(400, "Name cannot be empty")
+    if len(name) > 100:
+        raise HTTPException(400, "Name too long (max 100 chars)")
+    exists = face_db.person_name_exists(name)
+    if exists and not body.use_existing:
+        raise HTTPException(
+            409, f"A person named '{name}' already exists. Set use_existing=true to link."
+        )
+    person_id, is_new = face_db.get_or_create_person(name)
+    face_db.link_detection_to_person(detection_id, person_id)
+    return {"detection_id": detection_id, "person_id": person_id, "is_new": is_new}
+
+
+class _ClusterRequest(BaseModel):
+    threshold: float = 0.45
+
+
+class _MergeGroupsRequest(BaseModel):
+    keep_group_id: int
+    discard_group_id: int
+
+
+class _IdentifyGroupRequest(BaseModel):
+    name: str
+    use_existing: bool = False
+
+
+@router.get("/groups")
+async def list_face_groups(_: dict = Depends(require_admin)):
+    """List face groups with ≥ 2 unidentified detections."""
+    import face_db
+    groups = face_db.get_groups_with_detections()
+    return {"groups": groups, "total": len(groups)}
+
+
+@router.get("/groups/{group_id}/detections")
+async def get_face_group_detections(group_id: int, _: dict = Depends(require_admin)):
+    """Return the unidentified detections belonging to a group (fetched on expand)."""
+    import face_db
+    detections = face_db.get_group_detections(group_id)
+    return {"detections": detections}
+
+
+@router.post("/groups/compute")
+async def compute_face_groups(body: _ClusterRequest, _: dict = Depends(require_admin)):
+    """Run full re-cluster of all unidentified faces."""
+    from face_service import get_face_service
+    if not 0.3 <= body.threshold <= 0.7:
+        raise HTTPException(422, "threshold must be between 0.3 and 0.7")
+    svc = get_face_service()
+    if not svc.available:
+        raise HTTPException(503, "Face service not available")
+    groups = await svc.cluster_unidentified_faces(body.threshold)
+    total_detections = sum(len(g) for g in groups)
+    return {
+        "groups_created": len(groups),
+        "total_detections_clustered": total_detections,
+        "threshold": body.threshold,
+    }
+
+
+@router.post("/groups/merge")
+async def merge_face_groups(body: _MergeGroupsRequest, _: dict = Depends(require_admin)):
+    """Merge two groups into one."""
+    import face_db
+    if body.keep_group_id == body.discard_group_id:
+        raise HTTPException(400, "keep_group_id and discard_group_id must differ")
+    groups_index = {g["id"] for g in face_db.get_groups_with_detections()}
+    if body.keep_group_id not in groups_index:
+        raise HTTPException(404, f"Group {body.keep_group_id} not found")
+    if body.discard_group_id not in groups_index:
+        raise HTTPException(404, f"Group {body.discard_group_id} not found")
+    face_db.merge_groups(body.keep_group_id, body.discard_group_id)
+    return {"ok": True, "surviving_group_id": body.keep_group_id}
+
+
+@router.post("/groups/{group_id}/identify")
+async def identify_face_group(
+    group_id: int,
+    body: _IdentifyGroupRequest,
+    _: dict = Depends(require_admin),
+):
+    """
+    Identify all detections in a group as one person.
+    After identification, auto-links any similar ungrouped unidentified detections.
+    Works for both input and output source types.
+    """
+    import face_db
+    name = body.name.strip()
+    if not name:
+        raise HTTPException(400, "Name cannot be empty")
+    if len(name) > 100:
+        raise HTTPException(400, "Name too long (max 100 chars)")
+
+    detections = face_db.get_group_detections(group_id)
+    if not detections:
+        raise HTTPException(404, f"Group {group_id} not found or has no unidentified detections")
+
+    exists = face_db.person_name_exists(name)
+    if exists and not body.use_existing:
+        raise HTTPException(
+            409, f"A person named '{name}' already exists. Set use_existing=true to link."
+        )
+
+    person_id, is_new = face_db.get_or_create_person(name)
+    identified_count = 0
+    for det in detections:
+        face_db.link_detection_to_person(det["id"], person_id)
+        identified_count += 1
+
+    face_db.delete_group(group_id)
+
+    # Post-identify: auto-link ungrouped unidentified detections similar to this person
+    ungrouped = face_db.get_ungrouped_unidentified_embeddings()
+    auto_linked_count = _auto_link_for_person(person_id, face_db, ungrouped)
+
+    return {
+        "person_id": person_id,
+        "person_name": name,
+        "is_new": is_new,
+        "identified_count": identified_count,
+        "auto_linked_count": auto_linked_count,
+    }
+
+
+@router.delete("/groups/{group_id}/detections/{detection_id}")
+async def remove_group_detection(
+    group_id: int,
+    detection_id: int,
+    _: dict = Depends(require_admin),
+):
+    """Remove a single detection from its group. Cleans up singleton groups."""
+    import face_db
+    det = face_db.get_detection(detection_id)
+    if det is None or det.get("group_id") != group_id:
+        raise HTTPException(404, "Detection not found in this group")
+    face_db.remove_detection_from_group(detection_id)
+    face_db.delete_singleton_groups()
+    return {"ok": True}
+
+
+@router.post("/rescan/outputs")
+async def rescan_output_embeddings(_: dict = Depends(require_admin)):
+    """Re-detect faces in stored output images to rebuild NULL embeddings."""
+    import face_db
+    from face_service import get_face_service
+    svc = get_face_service()
+    if not svc.available:
+        raise HTTPException(503, "Face service not available")
+    source_ids = face_db.get_null_embedding_output_source_ids()
+    total_updated = 0
+    for source_id in source_ids:
+        updated = await svc.rescan_output_embedding(source_id)
+        total_updated += updated
+    return {"processed": len(source_ids), "updated": total_updated}
+
+
+@router.post("/identify")
+async def identify_faces(body: _IdentifyRequest, _: dict = Depends(require_admin)):
+    """
+    Identify one or more face detections by name.
+
+    - If the name is new → creates a new person and links the detection.
+    - If the name exists and ``use_existing=true`` → links to the existing person.
+    - If the name exists and ``use_existing=false`` → HTTP 409.
+    - Only detections with ``source_type='input'`` may be identified via the web UI.
+    """
+    import face_db
+    results = []
+    for item in body.identifications:
+        name = item.name.strip()
+        if not name:
+            raise HTTPException(400, "Name cannot be empty")
+        if len(name) > 100:
+            raise HTTPException(400, "Name too long (max 100 chars)")
+
+        det = face_db.get_detection(item.detection_id)
+        if det is None:
+            raise HTTPException(404, f"Detection {item.detection_id} not found")
+        if det["source_type"] != "input":
+            raise HTTPException(403, "Only input-image detections may be identified via web UI")
+
+        exists = face_db.person_name_exists(name)
+        if exists and not item.use_existing:
+            raise HTTPException(
+                409,
+                f"A person named '{name}' already exists. Set use_existing=true to link.",
+            )
+
+        person_id, is_new = face_db.get_or_create_person(name)
+        face_db.link_detection_to_person(item.detection_id, person_id)
+
+        results.append({
+            "detection_id": item.detection_id,
+            "person_id": person_id,
+            "person_name": name,
+            "is_new": is_new,
+        })
+
+    # Auto-link similar ungrouped faces for each person identified in this batch
+    ungrouped = face_db.get_ungrouped_unidentified_embeddings()
+    unique_pids = {r["person_id"] for r in results}
+    auto_linked_count = sum(_auto_link_for_person(pid, face_db, ungrouped) for pid in unique_pids)
+    return {"identifications": results, "auto_linked_count": auto_linked_count}
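The _auto_link_for_person helper at the top of this file compares each ungrouped embedding against every reference embedding by cosine similarity: both sides are L2-normalized (with a 1e-8 guard against zero vectors), so the matrix-vector product ref_M_norm @ norm_ue yields one cosine score per reference row, and the max is checked against _SIMILARITY_THRESHOLD. A standalone sketch of that math with made-up vectors (the 0.45 threshold is a stand-in; the real value lives in face_service):

import numpy as np

threshold = 0.45  # stand-in for _SIMILARITY_THRESHOLD

ref = np.array([[0.0, 1.0, 0.0],
                [0.6, 0.8, 0.0]])   # reference embeddings, one row per detection
query = np.array([0.0, 2.0, 0.0])   # an ungrouped embedding (scale is irrelevant)

ref_norm = ref / (np.linalg.norm(ref, axis=1, keepdims=True) + 1e-8)
q_norm = query / (np.linalg.norm(query) + 1e-8)

scores = ref_norm @ q_norm                # cosine similarity per reference row
print(scores, scores.max() >= threshold)  # [1.  0.8] True -> would be auto-linked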
web/routers/workflow_router.py
@@ -17,6 +17,32 @@ router = APIRouter()
 logger = logging.getLogger(__name__)


+def _materialize_image_slots(
+    overrides: dict, comfy_input_path: str
+) -> tuple[dict, list[str]]:
+    """
+    For each override whose value is an existing ttb_* file, copy it to a
+    unique name so concurrent jobs each have an immutable copy on disk.
+    Returns (updated_overrides, paths_to_delete_after_generation).
+    """
+    import shutil
+    import uuid as _uuid
+    if not comfy_input_path:
+        return overrides, []
+    updated = dict(overrides)
+    cleanup: list[str] = []
+    input_dir = Path(comfy_input_path)
+    for key, val in overrides.items():
+        if isinstance(val, str) and val.startswith("ttb_") and "." in val:
+            src = input_dir / val
+            if src.is_file():
+                unique_name = f"{src.stem}_{_uuid.uuid4().hex[:8]}{src.suffix}"
+                shutil.copy2(src, input_dir / unique_name)
+                updated[key] = unique_name
+                cleanup.append(str(input_dir / unique_name))
+    return updated, cleanup
+
+
 class GenerateRequest(BaseModel):
     prompt: str
     negative_prompt: Optional[str] = None
@@ -105,7 +131,8 @@ async def generate(body: GenerateRequest, user: dict = Depends(require_auth)):
         from generation_db import record_generation, record_file
         gen_id = record_generation(pid, "web", user_label, overrides_for_gen, seed_used)
         for i, img_data in enumerate(images):
-            record_file(gen_id, f"image_{i:04d}.png", img_data)
+            file_id = record_file(gen_id, f"image_{i:04d}.png", img_data)
+            comfy._schedule_face_scan("image", file_id, img_data)
         if config and videos:
             for vid in videos:
                 vsub = vid.get("video_subfolder", "")
@@ -116,7 +143,9 @@ async def generate(body: GenerateRequest, user: dict = Depends(require_auth)):
                     else Path(config.comfy_output_path) / vname
                 )
                 try:
-                    record_file(gen_id, vname, vpath.read_bytes())
+                    vid_data = vpath.read_bytes()
+                    file_id = record_file(gen_id, vname, vid_data)
+                    comfy._schedule_face_scan("video", file_id, vid_data)
                 except OSError:
                     pass
     except Exception as exc:
@@ -163,25 +192,32 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
     registry = get_user_registry()
     count = max(1, min(body.count, 20))  # cap at 20

-    async def _run_one():
-        # Use the user's own state and template
-        if registry:
-            user_sm = registry.get_state_manager(user_label)
-            user_template = registry.get_workflow_template(user_label)
-        else:
-            user_sm = comfy.state_manager
-            user_template = comfy.workflow_manager.get_workflow_template()
+    # --- snapshot state at queue time, not at execution time ---
+    if registry:
+        _user_sm = registry.get_state_manager(user_label)
+        _user_template = registry.get_workflow_template(user_label)
+    else:
+        _user_sm = comfy.state_manager
+        _user_template = comfy.workflow_manager.get_workflow_template()

-        if not user_template:
+    base_overrides = _user_sm.get_overrides()
+    if body.overrides:
+        base_overrides = {**base_overrides, **body.overrides}
+
+    _config = get_config()
+
+    async def _run_one(overrides: dict, cleanup_paths: list[str]):
+        if not _user_template:
             await bus.broadcast_to_user(user_label, "generation_error", {
                 "prompt_id": None, "error": "No workflow template loaded"
             })
+            for p in cleanup_paths:
+                try:
+                    Path(p).unlink(missing_ok=True)
+                except Exception:
+                    pass
             return

-        overrides = user_sm.get_overrides()
-        if body.overrides:
-            overrides = {**overrides, **body.overrides}
-
         import uuid
         pid = str(uuid.uuid4())
@@ -190,7 +226,7 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
                 "node": node, "prompt_id": pid_
             }))

-        workflow, applied = comfy.inspector.inject_overrides(user_template, overrides)
+        workflow, applied = comfy.inspector.inject_overrides(_user_template, overrides)
         seed_used = applied.get("seed")
         comfy.last_seed = seed_used
@@ -201,17 +237,23 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
             await bus.broadcast_to_user(user_label, "generation_error", {
                 "prompt_id": None, "error": str(exc)
             })
+            for p in cleanup_paths:
+                try:
+                    Path(p).unlink(missing_ok=True)
+                except Exception:
+                    pass
             return

         comfy.last_prompt_id = pid
         comfy.total_generated += 1

-        config = get_config()
+        config = _config
         try:
             from generation_db import record_generation, record_file
             gen_id = record_generation(pid, "web", user_label, overrides, seed_used)
             for i, img_data in enumerate(images):
-                record_file(gen_id, f"image_{i:04d}.png", img_data)
+                file_id = record_file(gen_id, f"image_{i:04d}.png", img_data)
+                comfy._schedule_face_scan("image", file_id, img_data)
             if config and videos:
                 for vid in videos:
                     vsub = vid.get("video_subfolder", "")
@@ -222,7 +264,9 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
                         else Path(config.comfy_output_path) / vname
                     )
                     try:
-                        record_file(gen_id, vname, vpath.read_bytes())
+                        vid_data = vpath.read_bytes()
+                        file_id = record_file(gen_id, vname, vid_data)
+                        comfy._schedule_face_scan("video", file_id, vid_data)
                     except OSError:
                         pass
         except Exception as exc:
@@ -236,6 +280,13 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au
                 config.media_upload_pass,
             ))

+        # Clean up unique image copies now that ComfyUI has ingested them
+        for p in cleanup_paths:
+            try:
+                Path(p).unlink(missing_ok=True)
+            except Exception:
+                pass
+
         await bus.broadcast("queue_update", {"prompt_id": pid, "status": "complete"})
         await bus.broadcast_to_user(user_label, "generation_complete", {
             "prompt_id": pid,
@@ -246,7 +297,10 @@ async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_au

     depth = await comfy.get_queue_depth()
     for _ in range(count):
-        asyncio.create_task(_run_one())
+        job_overrides, cleanup = _materialize_image_slots(
+            base_overrides, _config.comfy_input_path if _config else ""
+        )
+        asyncio.create_task(_run_one(job_overrides, cleanup))

     return {
         "queued": True,
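Per its docstring, _materialize_image_slots copies each ttb_* input to a uuid-suffixed name so every queued job holds an immutable copy on disk, presumably so a re-upload between queue time and execution time cannot swap the image under an already-queued job. A minimal sketch of the naming it produces (the path is illustrative):

import uuid
from pathlib import Path

src = Path("/comfy/input/ttb_slot_a.png")  # illustrative path
unique = f"{src.stem}_{uuid.uuid4().hex[:8]}{src.suffix}"
print(unique)                              # e.g. ttb_slot_a_3fa94c1e.png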
web/routers/history_router.py
@@ -3,12 +3,14 @@ POST /api/history/{prompt_id}/share; DELETE /api/history/{prompt_id}/share"""
 from __future__ import annotations

 import base64
+from datetime import datetime, timedelta, timezone
 from typing import Optional

-from fastapi import APIRouter, Depends, HTTPException, Query, Request
+from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request
 from fastapi.responses import Response
 from pydantic import BaseModel

-from web.auth import require_auth
+from web.auth import require_admin, require_auth

 router = APIRouter()
@@ -31,22 +33,135 @@ def _assert_owns(prompt_id: str, user: dict) -> None:
 async def get_history(
     user: dict = Depends(require_auth),
     q: Optional[str] = Query(None, description="Keyword to search in overrides JSON"),
+    persons: list[str] = Query(default=[], alias="persons", description="Filter by person name/alias substring (repeatable)"),
 ):
     """Return generation history. Admins see all; regular users see only their own.
-    Pass ?q=keyword to filter by prompt text or any override field."""
+    Pass ?q=keyword to filter by prompt text or any override field.
+    Pass ?persons=name (repeatable) to filter by persons detected in output images."""
+    import face_db as face_db_mod
     from generation_db import (
         get_history as db_get_history,
         get_history_for_user,
         search_history,
         search_history_for_user,
+        get_generation_ids_for_file_ids,
+        get_file_ids_for_generation_ids,
     )
+
+    # Person filter: OR union across all named persons, then filter rows
+    gen_id_filter: list[int] | None = None
+    all_file_ids: set[int] = set()
+    active_persons = [p.strip() for p in persons if p.strip()]
+    if active_persons:
+        for p in active_persons:
+            ids = face_db_mod.get_source_ids_for_person_query(p, "output")
+            all_file_ids.update(ids)
+        if not all_file_ids:
+            return {"history": []}
+        gen_ids = get_generation_ids_for_file_ids(list(all_file_ids))
+        if not gen_ids:
+            return {"history": []}
+        gen_id_filter = gen_ids
+
     if q and q.strip():
         if user.get("admin"):
-            return {"history": search_history(q.strip(), limit=50)}
-        return {"history": search_history_for_user(user["sub"], q.strip(), limit=50)}
-    if user.get("admin"):
-        return {"history": db_get_history(limit=50)}
-    return {"history": get_history_for_user(user["sub"], limit=50)}
+            rows = search_history(q.strip(), limit=50)
+        else:
+            rows = search_history_for_user(user["sub"], q.strip(), limit=50)
+    else:
+        if user.get("admin"):
+            rows = db_get_history(limit=50)
+        else:
+            rows = get_history_for_user(user["sub"], limit=50)
+
+    if gen_id_filter is not None:
+        gen_id_set = set(gen_id_filter)
+        rows = [r for r in rows if r["id"] in gen_id_set]
+
+    # Annotate each row with detected_persons when person filtering is active
+    if active_persons and rows:
+        row_ids = [r["id"] for r in rows]
+        file_ids_per_gen = get_file_ids_for_generation_ids(row_ids)
+        all_annotate_ids = [fid for fids in file_ids_per_gen.values() for fid in fids]
+        person_map = face_db_mod.get_persons_for_source_id_map(all_annotate_ids, "output") if all_annotate_ids else {}
+        for row in rows:
+            row["detected_persons"] = list({
+                name
+                for fid in file_ids_per_gen.get(row["id"], [])
+                for name in person_map.get(fid, [])
+            })
+
+    return {"history": rows}
+
+
+@router.get("/{prompt_id}/persons")
+async def get_generation_persons(prompt_id: str, user: dict = Depends(require_auth)):
+    """Return persons detected in the output images of a generation."""
+    _assert_owns(prompt_id, user)
+    import generation_db as gen_db
+    import face_db as face_db_mod
+    file_ids = gen_db.get_file_ids_for_prompt(prompt_id)
+    persons = face_db_mod.get_persons_for_source_ids(file_ids, "output")
+    return {"persons": persons}
+
+
+class _AddPersonRequest(BaseModel):
+    name: str
+
+
+@router.post("/{prompt_id}/persons")
+async def add_generation_person(
+    prompt_id: str,
+    body: _AddPersonRequest,
+    _: dict = Depends(require_admin),
+):
+    """Manually tag a person to a generation's output files."""
+    import generation_db as gen_db
+    import face_db as face_db_mod
+
+    file_ids = gen_db.get_file_ids_for_prompt(prompt_id)
+    if not file_ids:
+        raise HTTPException(404, f"Generation {prompt_id!r} not found")
+
+    name = body.name.strip()
+    if not name:
+        raise HTTPException(400, "Name cannot be empty")
+    if len(name) > 100:
+        raise HTTPException(400, "Name too long (max 100 chars)")
+
+    person_id, _ = face_db_mod.get_or_create_person(name)
+
+    # Only add if not already tagged on any file of this generation
+    existing = face_db_mod.get_persons_for_source_ids(file_ids, "output")
+    if not any(p["id"] == person_id for p in existing):
+        face_db_mod.insert_detection(
+            source_type="output",
+            source_id=file_ids[0],
+            embedding=None,
+            bbox={},
+            face_index=0,
+            person_id=person_id,
+        )
+
+    return {"person_id": person_id, "name": name}
+
+
+@router.delete("/{prompt_id}/persons/{person_id}")
+async def remove_generation_person(
+    prompt_id: str,
+    person_id: int,
+    _: dict = Depends(require_admin),
+):
+    """Remove all face detections linking a person to this generation's output files."""
+    import generation_db as gen_db
+    import face_db as face_db_mod
+
+    file_ids = gen_db.get_file_ids_for_prompt(prompt_id)
+    if not file_ids:
+        raise HTTPException(404, f"Generation {prompt_id!r} not found")
+
+    face_db_mod.remove_person_from_source_ids(file_ids, person_id, "output")
+    return {"ok": True}


 @router.get("/{prompt_id}/images")
@@ -101,6 +216,11 @@ async def get_history_file(
         start = int(start_str) if start_str else 0
         end = int(end_str) if end_str else total - 1
+        end = min(end, total - 1)
+        if start < 0 or start > end:
+            return Response(
+                status_code=416,
+                headers={"Content-Range": f"bytes */{total}"},
+            )
         chunk = data[start : end + 1]
         return Response(
             content=chunk,
@@ -110,6 +230,7 @@ async def get_history_file(
             "Content-Range": f"bytes {start}-{end}/{total}",
             "Accept-Ranges": "bytes",
             "Content-Length": str(len(chunk)),
+            "Cache-Control": "private",
         },
     )
@@ -119,25 +240,62 @@ async def get_history_file(
         headers={
             "Accept-Ranges": "bytes",
             "Content-Length": str(total),
+            "Cache-Control": "private",
         },
     )


+class _ShareRequest(BaseModel):
+    is_public: bool = False
+    expires_in_hours: Optional[float] = None
+    max_views: Optional[int] = None
+
+
 @router.post("/{prompt_id}/share")
-async def create_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
+async def create_generation_share(
+    prompt_id: str,
+    body: _ShareRequest = Body(default_factory=_ShareRequest),
+    user: dict = Depends(require_auth),
+):
     """Create a share token for a generation. Only the owner may share."""
     # Use the same 404-for-everything helper to avoid leaking prompt_id existence
     _assert_owns(prompt_id, user)
-    from generation_db import create_share as db_create_share
-    token = db_create_share(prompt_id, user["sub"])
-    return {"share_token": token}
+    from generation_db import (
+        create_share as db_create_share,
+        get_active_share_for_prompt as db_get_active_share,
+    )
+    existing = db_get_active_share(prompt_id, user["sub"])
+    if existing:
+        raise HTTPException(409, "A share already exists — revoke it first")
+    if body.is_public and body.expires_in_hours is None and body.max_views is None:
+        raise HTTPException(400, "Public shares must have at least one expiry condition (time or view limit)")
+    if body.max_views is not None and body.max_views < 1:
+        raise HTTPException(400, "max_views must be >= 1")
+    if body.expires_in_hours is not None and body.expires_in_hours <= 0:
+        raise HTTPException(400, "expires_in_hours must be > 0")
+    expires_at = None
+    if body.expires_in_hours is not None:
+        expires_at = (datetime.now(timezone.utc) + timedelta(hours=body.expires_in_hours)).isoformat()
+    share = db_create_share(
+        prompt_id, user["sub"],
+        is_public=body.is_public,
+        expires_at=expires_at,
+        max_views=body.max_views,
+    )
+    return {
+        "share_token": share["share_token"],
+        "is_public": bool(share["is_public"]),
+        "expires_at": share["expires_at"],
+        "max_views": share["max_views"],
+    }


 @router.delete("/{prompt_id}/share")
 async def revoke_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
-    """Revoke a share token for a generation. Only the owner may revoke."""
+    """Revoke a share token for a generation. Admins can revoke any share."""
     from generation_db import revoke_share as db_revoke_share
-    deleted = db_revoke_share(prompt_id, user["sub"])
+    # Admins pass None for owner_label to delete by prompt_id alone
+    owner_label = None if user.get("admin") else user["sub"]
+    deleted = db_revoke_share(prompt_id, owner_label)
     if not deleted:
         raise HTTPException(404, "No active share found for this generation")
     return {"ok": True}
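The reworked share endpoint accepts an optional JSON body, and a public share must carry at least one expiry condition (a time limit or a view cap), otherwise it is rejected with a 400. A hedged client-side sketch using requests (host, cookie value, and prompt id are placeholders; the route comes from this module's docstring):

import requests

resp = requests.post(
    "http://localhost:8080/api/history/<prompt_id>/share",  # placeholder host and id
    json={"is_public": True, "expires_in_hours": 24, "max_views": 10},
    cookies={"ttb_session": "<jwt>"},                       # session cookie from login
)
print(resp.status_code)  # 409 if a share already exists, 400 if a public share has no expiry
print(resp.json())       # share_token, is_public, expires_at, max_views on success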
web/routers/inputs_router.py
@@ -3,10 +3,11 @@ from __future__ import annotations

 import logging
 import mimetypes
+import re
 from pathlib import Path
 from typing import Optional

-from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, UploadFile
+from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, Query, UploadFile
 from fastapi.responses import Response

 from web.auth import require_auth
@@ -15,10 +16,38 @@ from web.deps import get_config, get_user_registry
 router = APIRouter()
 logger = logging.getLogger(__name__)

+_ALLOWED_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
+_MAX_UPLOAD_BYTES = 50 * 1024 * 1024  # 50 MB
+_SAFE_SLOT_RE = re.compile(r'^[a-zA-Z0-9_\-]+$')
+
+
+def _validate_slot_key(slot_key: str) -> None:
+    if not _SAFE_SLOT_RE.match(slot_key):
+        raise HTTPException(400, "slot_key may only contain letters, digits, hyphens, underscores")
+
+
 @router.get("")
-async def list_inputs(_: dict = Depends(require_auth)):
-    """List all input images (Discord + web uploads)."""
+async def list_inputs(
+    _: dict = Depends(require_auth),
+    persons: list[str] = Query(default=[], alias="persons", description="Filter by person name/alias substring (repeatable)"),
+):
+    """List all input images (Discord + web uploads). Optionally filter by persons."""
+    active_persons = [p.strip() for p in persons if p.strip()]
+    if active_persons:
+        import face_db as face_db_mod
+        from input_image_db import get_images_by_ids
+        all_ids: set[int] = set()
+        for p in active_persons:
+            ids = face_db_mod.get_source_ids_for_person_query(p, "input")
+            all_ids.update(ids)
+        images = list(get_images_by_ids(list(all_ids))) if all_ids else []
+        if images:
+            person_map = face_db_mod.get_persons_for_source_id_map(
+                [img["id"] for img in images], "input"
+            )
+            for img in images:
+                img["detected_persons"] = person_map.get(img["id"], [])
+        return images
     from input_image_db import get_all_images
     rows = get_all_images()
     return [dict(r) for r in rows]
@@ -44,8 +73,16 @@ async def upload_input(
     if config is None:
         raise HTTPException(503, "Config not available")

+    if slot_key:
+        _validate_slot_key(slot_key)
+
     data = await file.read()
     filename = file.filename or "upload.png"
+    ext = Path(filename).suffix.lower()
+    if ext not in _ALLOWED_EXTS:
+        raise HTTPException(415, f"Unsupported file type '{ext}'. Allowed: {sorted(_ALLOWED_EXTS)}")
+    if len(data) > _MAX_UPLOAD_BYTES:
+        raise HTTPException(413, "File too large (max 50 MB)")

     from input_image_db import upsert_image, activate_image_for_slot
     row_id = upsert_image(
@@ -72,7 +109,30 @@ async def upload_input(
     if comfy:
         comfy.state_manager.set_override(slot_key, activated_filename)

-    return {"id": row_id, "filename": filename, "slot_key": slot_key, "activated_filename": activated_filename}
+    # Face scan — runs synchronously here (~1-2 s); unknown faces returned to UI
+    pending_faces: list[dict] = []
+    try:
+        from face_service import get_face_service
+        import face_db as _face_db
+        _face_db.init_db()
+        svc = get_face_service()
+        if svc.available:
+            results = await svc.scan_input_image(row_id, data)
+            pending_faces = [
+                {"detection_id": r.detection_id, "face_index": r.face_index, "bbox": r.bbox}
+                for r in results
+                if r.matched_person_id is None
+            ]
+    except Exception as exc:
+        logger.warning("Face scan failed for upload row_id=%d: %s", row_id, exc)
+
+    return {
+        "id": row_id,
+        "filename": filename,
+        "slot_key": slot_key,
+        "activated_filename": activated_filename,
+        "pending_faces": pending_faces,
+    }


 @router.post("/{row_id}/activate")
@@ -92,6 +152,7 @@ async def activate_input(
         raise HTTPException(404, "Image not found")

     user_label: str = user["sub"]
+    _validate_slot_key(slot_key)
     namespaced_key = f"{user_label}_{slot_key}"

     try:
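Uploads are now gated before anything touches the database: an extension outside _ALLOWED_EXTS earns a 415, anything over 50 MB a 413, and a malformed slot_key a 400, while the success response carries pending_faces for detected-but-unknown faces. A sketch of exercising this with requests (the mount point and host are assumptions; the form fields follow the handler's signature):

import requests

with open("portrait.png", "rb") as fh:
    resp = requests.post(
        "http://localhost:8080/api/inputs",      # assumed mount point for this router
        files={"file": ("portrait.png", fh, "image/png")},
        data={"slot_key": "slot_a"},             # must match ^[a-zA-Z0-9_\-]+$
        cookies={"ttb_session": "<jwt>"},
    )
print(resp.json().get("pending_faces"))          # unknown faces awaiting identification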
web/routers/share_router.py
@@ -2,28 +2,40 @@
 from __future__ import annotations

 import base64
+from typing import Optional

 from fastapi import APIRouter, Depends, HTTPException, Request
 from fastapi.responses import Response

-from web.auth import require_auth
+from web.auth import optional_auth

 router = APIRouter()


+def _auth_gate(gen: dict, user: Optional[dict]) -> None:
+    """Raise 401 if the share is private and the user is not authenticated."""
+    if not gen.get("is_public") and user is None:
+        raise HTTPException(401, "Authentication required")
+
+
 @router.get("/{token}")
-async def get_share(token: str, _: dict = Depends(require_auth)):
-    """Fetch share metadata and images. Any authenticated user may view a valid share link."""
+async def get_share(token: str, user: Optional[dict] = Depends(optional_auth)):
+    """Fetch share metadata and images. Public shares require no login; private shares require auth."""
     from generation_db import get_share_by_token, get_files
     gen = get_share_by_token(token)
     if gen is None:
-        raise HTTPException(404, "Share not found or revoked")
+        raise HTTPException(404, "Share not found, expired, or revoked")
+    _auth_gate(gen, user)
     files = get_files(gen["prompt_id"])
     return {
         "prompt_id": gen["prompt_id"],
         "created_at": gen["created_at"],
         "overrides": gen["overrides"],
         "seed": gen["seed"],
+        "is_public": bool(gen["is_public"]),
+        "expires_at": gen["expires_at"],
+        "max_views": gen["max_views"],
+        "view_count": gen["view_count"],
         "images": [
             {
                 "filename": f["filename"],
@@ -40,13 +52,14 @@ async def get_share_file(
     token: str,
     filename: str,
     request: Request,
-    _: dict = Depends(require_auth),
+    user: Optional[dict] = Depends(optional_auth),
 ):
     """Stream a single output file via share token, with HTTP range support for video seeking."""
-    from generation_db import get_share_by_token, get_files
-    gen = get_share_by_token(token)
+    from generation_db import get_share_meta, get_files
+    gen = get_share_meta(token)
     if gen is None:
-        raise HTTPException(404, "Share not found or revoked")
+        raise HTTPException(404, "Share not found, expired, or revoked")
+    _auth_gate(gen, user)
     files = get_files(gen["prompt_id"])
     matched = next((f for f in files if f["filename"] == filename), None)
     if matched is None:
@@ -63,6 +76,11 @@ async def get_share_file(
         start = int(start_str) if start_str else 0
         end = int(end_str) if end_str else total - 1
+        end = min(end, total - 1)
+        if start < 0 or start > end:
+            return Response(
+                status_code=416,
+                headers={"Content-Range": f"bytes */{total}"},
+            )
         chunk = data[start : end + 1]
         return Response(
             content=chunk,
@@ -72,6 +90,7 @@ async def get_share_file(
             "Content-Range": f"bytes {start}-{end}/{total}",
             "Accept-Ranges": "bytes",
             "Content-Length": str(len(chunk)),
+            "Cache-Control": "public, max-age=3600" if gen.get("is_public") else "private",
         },
     )
@@ -81,5 +100,6 @@ async def get_share_file(
         headers={
             "Accept-Ranges": "bytes",
             "Content-Length": str(total),
+            "Cache-Control": "public, max-age=3600" if gen.get("is_public") else "private",
         },
    )
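Both file-streaming endpoints (here and in history_router) now clamp end to the file size and answer unsatisfiable ranges with a 416 plus a "Content-Range: bytes */{total}" header, which is what lets browsers seek inside shared videos. A quick sketch of the client side (the token and filename segment are placeholders; only the /api/share prefix is fixed by app.py):

import requests

url = "http://localhost:8080/api/share/<token>/<filename>"  # placeholder route shape
r = requests.get(url, headers={"Range": "bytes=0-1023"})
print(r.status_code)                   # 206 Partial Content for a satisfiable range
print(r.headers.get("Content-Range"))  # e.g. bytes 0-1023/4816293

bad = requests.get(url, headers={"Range": "bytes=999999999-"})
print(bad.status_code)                 # 416 when start is past the end of the file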