"""POST /api/generate and /api/workflow-gen""" from __future__ import annotations import asyncio import logging from pathlib import Path from typing import Any, Dict, Optional from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel from web.auth import require_auth from web.deps import get_comfy, get_config, get_user_registry from web.ws_bus import get_bus router = APIRouter() logger = logging.getLogger(__name__) class GenerateRequest(BaseModel): prompt: str negative_prompt: Optional[str] = None overrides: Optional[Dict[str, Any]] = None # extra per-request overrides class WorkflowGenRequest(BaseModel): count: int = 1 overrides: Optional[Dict[str, Any]] = None # per-request overrides (merged with state) @router.post("/generate") async def generate(body: GenerateRequest, user: dict = Depends(require_auth)): """Submit a prompt-based generation to ComfyUI.""" comfy = get_comfy() if comfy is None: raise HTTPException(503, "ComfyUI client not available") user_label: str = user["sub"] bus = get_bus() registry = get_user_registry() # Temporary seed override from request if body.overrides and "seed" in body.overrides: seed_override = body.overrides["seed"] elif registry: seed_override = registry.get_state_manager(user_label).get_seed() else: seed_override = comfy.state_manager.get_seed() overrides_for_gen = {"prompt": body.prompt} if body.negative_prompt: overrides_for_gen["negative_prompt"] = body.negative_prompt if seed_override is not None: overrides_for_gen["seed"] = seed_override # Also apply any extra per-request overrides if body.overrides: overrides_for_gen.update(body.overrides) # Get queue position estimate depth = await comfy.get_queue_depth() # Start generation as background task so we can return the prompt_id immediately prompt_id_holder: list = [] async def _run(): # Use the user's own workflow template if registry: template = registry.get_workflow_template(user_label) else: template = comfy.workflow_manager.get_workflow_template() if not template: await bus.broadcast_to_user(user_label, "generation_error", { "prompt_id": None, "error": "No workflow template loaded" }) return import uuid pid = str(uuid.uuid4()) prompt_id_holder.append(pid) def on_progress(node, pid_): asyncio.create_task(bus.broadcast("node_executing", { "node": node, "prompt_id": pid_ })) workflow, applied = comfy.inspector.inject_overrides(template, overrides_for_gen) seed_used = applied.get("seed") comfy.last_seed = seed_used try: images, videos = await comfy._general_generate(workflow, pid, on_progress) except Exception as exc: logger.exception("Generation error for prompt %s", pid) await bus.broadcast_to_user(user_label, "generation_error", { "prompt_id": pid, "error": str(exc) }) return comfy.last_prompt_id = pid comfy.total_generated += 1 # Persist to DB before flush_pending deletes local files config = get_config() try: from generation_db import record_generation, record_file gen_id = record_generation(pid, "web", user_label, overrides_for_gen, seed_used) for i, img_data in enumerate(images): record_file(gen_id, f"image_{i:04d}.png", img_data) if config and videos: for vid in videos: vsub = vid.get("video_subfolder", "") vname = vid.get("video_name", "") vpath = ( Path(config.comfy_output_path) / vsub / vname if vsub else Path(config.comfy_output_path) / vname ) try: record_file(gen_id, vname, vpath.read_bytes()) except OSError: pass except Exception as exc: logger.warning("Failed to record generation to DB: %s", exc) # Flush auto-upload if config: from media_uploader import flush_pending asyncio.create_task(flush_pending( Path(config.comfy_output_path), config.media_upload_user, config.media_upload_pass, )) await bus.broadcast("queue_update", { "prompt_id": pid, "status": "complete", }) await bus.broadcast_to_user(user_label, "generation_complete", { "prompt_id": pid, "seed": seed_used, "image_count": len(images), "video_count": len(videos), }) asyncio.create_task(_run()) return { "queued": True, "queue_position": depth + 1, "message": "Generation submitted to ComfyUI", } @router.post("/workflow-gen") async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_auth)): """Submit workflow-based generation(s) to ComfyUI.""" comfy = get_comfy() if comfy is None: raise HTTPException(503, "ComfyUI client not available") user_label: str = user["sub"] bus = get_bus() registry = get_user_registry() count = max(1, min(body.count, 20)) # cap at 20 async def _run_one(): # Use the user's own state and template if registry: user_sm = registry.get_state_manager(user_label) user_template = registry.get_workflow_template(user_label) else: user_sm = comfy.state_manager user_template = comfy.workflow_manager.get_workflow_template() if not user_template: await bus.broadcast_to_user(user_label, "generation_error", { "prompt_id": None, "error": "No workflow template loaded" }) return overrides = user_sm.get_overrides() if body.overrides: overrides = {**overrides, **body.overrides} import uuid pid = str(uuid.uuid4()) def on_progress(node, pid_): asyncio.create_task(bus.broadcast("node_executing", { "node": node, "prompt_id": pid_ })) workflow, applied = comfy.inspector.inject_overrides(user_template, overrides) seed_used = applied.get("seed") comfy.last_seed = seed_used try: images, videos = await comfy._general_generate(workflow, pid, on_progress) except Exception as exc: logger.exception("Workflow gen error") await bus.broadcast_to_user(user_label, "generation_error", { "prompt_id": None, "error": str(exc) }) return comfy.last_prompt_id = pid comfy.total_generated += 1 config = get_config() try: from generation_db import record_generation, record_file gen_id = record_generation(pid, "web", user_label, overrides, seed_used) for i, img_data in enumerate(images): record_file(gen_id, f"image_{i:04d}.png", img_data) if config and videos: for vid in videos: vsub = vid.get("video_subfolder", "") vname = vid.get("video_name", "") vpath = ( Path(config.comfy_output_path) / vsub / vname if vsub else Path(config.comfy_output_path) / vname ) try: record_file(gen_id, vname, vpath.read_bytes()) except OSError: pass except Exception as exc: logger.warning("Failed to record generation to DB: %s", exc) if config: from media_uploader import flush_pending asyncio.create_task(flush_pending( Path(config.comfy_output_path), config.media_upload_user, config.media_upload_pass, )) await bus.broadcast("queue_update", {"prompt_id": pid, "status": "complete"}) await bus.broadcast_to_user(user_label, "generation_complete", { "prompt_id": pid, "seed": seed_used, "image_count": len(images), "video_count": len(videos), }) depth = await comfy.get_queue_depth() for _ in range(count): asyncio.create_task(_run_one()) return { "queued": True, "count": count, "queue_position": depth + 1, }