""" comfy_client.py ================ Asynchronous client for the ComfyUI API. Wraps ComfyUI's REST and WebSocket endpoints. Workflow template injection is now handled by :class:`~workflow_inspector.WorkflowInspector`, so this class only needs to: 1. Accept a workflow template (delegated to WorkflowManager). 2. Accept runtime overrides (delegated to WorkflowStateManager). 3. Build the final workflow via inspector.inject_overrides(). 4. Queue it to ComfyUI, wait for completion via WebSocket, fetch outputs. A ``{prompt_id: callback}`` map is maintained for future WebSocket broadcasting (web UI phase). Discord commands still use the synchronous await-and-return model. """ from __future__ import annotations import asyncio import json import logging import uuid from typing import Any, Callable, Dict, List, Optional import aiohttp import websockets from workflow_inspector import WorkflowInspector class ComfyClient: """ Asynchronous ComfyUI client. Parameters ---------- server_address : str ``hostname:port`` of the ComfyUI server. workflow_manager : WorkflowManager Template storage (injected). state_manager : WorkflowStateManager Runtime overrides (injected). logger : Optional[logging.Logger] Logger for debug/info messages. history_limit : int Max recent generations to keep in the in-memory deque. """ def __init__( self, server_address: str, workflow_manager, state_manager, logger: Optional[logging.Logger] = None, *, history_limit: int = 10, output_path: Optional[str] = None, ) -> None: self.server_address = server_address.strip().rstrip("/") self.client_id = str(uuid.uuid4()) self._session: Optional[aiohttp.ClientSession] = None self.protocol = "http" self.ws_protocol = "ws" self.workflow_manager = workflow_manager self.state_manager = state_manager self.inspector = WorkflowInspector() self.output_path = output_path # prompt_id → asyncio.Future for web-UI broadcast (Phase 4) self._pending_callbacks: Dict[str, Callable] = {} from collections import deque self._history = deque(maxlen=history_limit) self.last_prompt_id: Optional[str] = None self.last_seed: Optional[int] = None self.total_generated: int = 0 self.logger = logger if logger else logging.getLogger(__name__) # ------------------------------------------------------------------ # Session # ------------------------------------------------------------------ @property def session(self) -> aiohttp.ClientSession: """Lazily create and return an aiohttp session.""" if self._session is None or self._session.closed: self._session = aiohttp.ClientSession() return self._session # ------------------------------------------------------------------ # Low-level REST helpers # ------------------------------------------------------------------ async def _queue_prompt( self, prompt: dict[str, Any], prompt_id: str, ws_client_id: str | None = None, ) -> dict[str, Any]: """Submit a workflow to the ComfyUI queue.""" payload = { "prompt": prompt, "client_id": ws_client_id if ws_client_id is not None else self.client_id, "prompt_id": prompt_id, } url = f"{self.protocol}://{self.server_address}/prompt" async with self.session.post(url, json=payload, headers={"Content-Type": "application/json"}) as resp: resp.raise_for_status() return await resp.json() async def _wait_for_execution( self, prompt_id: str, on_progress: Optional[Callable[[str, str], None]] = None, ws_client_id: str | None = None, ) -> None: """ Wait for a queued prompt to finish executing via WebSocket. Parameters ---------- prompt_id : str The prompt to wait for. on_progress : Optional[Callable[[str, str], None]] Called with ``(node_id, prompt_id)`` for each ``node_executing`` event. Pass ``None`` for Discord commands (no web broadcast needed). """ client_id = ws_client_id if ws_client_id is not None else self.client_id ws_url = ( f"{self.ws_protocol}://{self.server_address}/ws" f"?clientId={client_id}" ) async with websockets.connect(ws_url) as ws: try: while True: out = await ws.recv() if not isinstance(out, str): continue message = json.loads(out) mtype = message.get("type") if mtype == "executing": data = message["data"] node = data.get("node") if node: self.logger.debug("Executing node: %s", node) if on_progress and data.get("prompt_id") == prompt_id: try: on_progress(node, prompt_id) except Exception: pass if data["node"] is None and data.get("prompt_id") == prompt_id: self.logger.info("Execution complete for prompt %s", prompt_id) break elif mtype == "execution_success": if message.get("data", {}).get("prompt_id") == prompt_id: self.logger.info("execution_success for prompt %s", prompt_id) break elif mtype == "execution_error": if message.get("data", {}).get("prompt_id") == prompt_id: error = message.get("data", {}).get("exception_message", "unknown error") raise RuntimeError(f"ComfyUI execution error: {error}") except Exception as exc: self.logger.error("Error during execution wait: %s", exc) raise async def _get_history(self, prompt_id: str) -> dict[str, Any]: """Retrieve execution history for a given prompt id.""" url = f"{self.protocol}://{self.server_address}/history/{prompt_id}" async with self.session.get(url) as resp: resp.raise_for_status() return await resp.json() async def _download_image(self, filename: str, subfolder: str, folder_type: str) -> bytes: """Download an image from ComfyUI and return raw bytes.""" url = f"{self.protocol}://{self.server_address}/view" params = {"filename": filename, "subfolder": subfolder, "type": folder_type} async with self.session.get(url, params=params) as resp: resp.raise_for_status() return await resp.read() # ------------------------------------------------------------------ # Core generation pipeline # ------------------------------------------------------------------ async def _general_generate( self, workflow: dict[str, Any], prompt_id: str, on_progress: Optional[Callable[[str, str], None]] = None, ) -> tuple[List[bytes], List[dict[str, Any]]]: """ Queue a workflow, wait for it to execute, then collect outputs. Returns ------- tuple[List[bytes], List[dict]] ``(images, videos)`` — images as raw bytes, videos as info dicts. """ ws_client_id = str(uuid.uuid4()) await self._queue_prompt(workflow, prompt_id, ws_client_id) try: await self._wait_for_execution(prompt_id, on_progress=on_progress, ws_client_id=ws_client_id) except Exception: self.logger.error("Execution failed for prompt %s", prompt_id) return [], [] history = await self._get_history(prompt_id) if not history: self.logger.warning("No history for prompt %s", prompt_id) return [], [] images: List[bytes] = [] videos: List[dict[str, Any]] = [] for node_output in history.get(prompt_id, {}).get("outputs", {}).values(): for image_info in node_output.get("images", []): name = image_info["filename"] if name.rsplit(".", 1)[-1].lower() in {"mp4", "webm", "avi"}: videos.append({ "video_name": name, "video_subfolder": image_info.get("subfolder", ""), "video_type": image_info.get("type", "output"), }) else: data = await self._download_image( name, image_info["subfolder"], image_info["type"] ) images.append(data) return images, videos # ------------------------------------------------------------------ # DB persistence helper # ------------------------------------------------------------------ def _record_to_db( self, prompt_id: str, source: str, user_label: Optional[str], overrides: Dict[str, Any], seed: Optional[int], images: List[bytes], videos: List[Dict[str, Any]], ) -> None: """Persist generation metadata and file blobs to SQLite. Never raises.""" try: import generation_db from pathlib import Path as _Path gen_id = generation_db.record_generation( prompt_id, source, user_label, overrides, seed ) for i, img_data in enumerate(images): generation_db.record_file(gen_id, f"image_{i:04d}.png", img_data) if videos and self.output_path: for vid in videos: vname = vid.get("video_name", "") vsub = vid.get("video_subfolder", "") vpath = ( _Path(self.output_path) / vsub / vname if vsub else _Path(self.output_path) / vname ) try: generation_db.record_file(gen_id, vname, vpath.read_bytes()) except OSError as exc: self.logger.warning( "Could not read video for DB storage: %s: %s", vpath, exc ) except Exception as exc: self.logger.warning("Failed to record generation to DB: %s", exc) # ------------------------------------------------------------------ # Public generation API # ------------------------------------------------------------------ async def generate_image( self, prompt: str, negative_prompt: Optional[str] = None, on_progress: Optional[Callable[[str, str], None]] = None, *, source: str = "discord", user_label: Optional[str] = None, ) -> tuple[List[bytes], str]: """ Generate images using the current workflow template with a text prompt. Injects *prompt* (and optionally *negative_prompt*) via the inspector, plus any currently pinned seed from the state manager. All other overrides in the state manager are **not** applied here — use :meth:`generate_image_with_workflow` for the full override set. Parameters ---------- prompt : str Positive prompt text. negative_prompt : Optional[str] Negative prompt text (optional). on_progress : Optional[Callable] Called with ``(node_id, prompt_id)`` for each executing node. Returns ------- tuple[List[bytes], str] ``(images, prompt_id)`` """ template = self.workflow_manager.get_workflow_template() if not template: self.logger.warning("No workflow template set; cannot generate.") return [], "" overrides: Dict[str, Any] = {"prompt": prompt} if negative_prompt is not None: overrides["negative_prompt"] = negative_prompt # Respect pinned seed from state manager seed_pin = self.state_manager.get_seed() if seed_pin is not None: overrides["seed"] = seed_pin workflow, applied = self.inspector.inject_overrides(template, overrides) seed_used = applied.get("seed") self.last_seed = seed_used prompt_id = str(uuid.uuid4()) images, _videos = await self._general_generate(workflow, prompt_id, on_progress) self.last_prompt_id = prompt_id self.total_generated += 1 self._history.append({ "prompt_id": prompt_id, "prompt": prompt, "negative_prompt": negative_prompt, "seed": seed_used, }) self._record_to_db( prompt_id, source, user_label, {"prompt": prompt, "negative_prompt": negative_prompt}, seed_used, images, [], ) return images, prompt_id async def generate_image_with_workflow( self, on_progress: Optional[Callable[[str, str], None]] = None, *, source: str = "discord", user_label: Optional[str] = None, ) -> tuple[List[bytes], List[dict[str, Any]], str]: """ Generate images/videos from the current workflow applying ALL overrides stored in the state manager. Returns ------- tuple[List[bytes], List[dict], str] ``(images, videos, prompt_id)`` """ template = self.workflow_manager.get_workflow_template() prompt_id = str(uuid.uuid4()) if not template: self.logger.error("No workflow template set") return [], [], prompt_id overrides = self.state_manager.get_overrides() workflow, applied = self.inspector.inject_overrides(template, overrides) seed_used = applied.get("seed") self.last_seed = seed_used images, videos = await self._general_generate(workflow, prompt_id, on_progress) self.last_prompt_id = prompt_id self.total_generated += 1 prompt_str = overrides.get("prompt") or "" neg_str = overrides.get("negative_prompt") or "" self._history.append({ "prompt_id": prompt_id, "prompt": (prompt_str[:10] + "…") if len(prompt_str) > 10 else prompt_str or None, "negative_prompt": (neg_str[:10] + "…") if len(neg_str) > 10 else neg_str or None, "seed": seed_used, }) self._record_to_db(prompt_id, source, user_label, overrides, seed_used, images, videos) return images, videos, prompt_id # ------------------------------------------------------------------ # Workflow template management # ------------------------------------------------------------------ def set_workflow(self, workflow: dict[str, Any]) -> None: """Set the workflow template and clear all state overrides.""" self.workflow_manager.set_workflow_template(workflow) self.state_manager.clear_overrides() def load_workflow_from_file(self, path: str) -> None: """ Load a workflow template from a JSON file. Also clears state overrides and records the filename in the state manager for auto-load on restart. """ import json as _json with open(path, "r", encoding="utf-8") as f: workflow = _json.load(f) self.workflow_manager.set_workflow_template(workflow) self.state_manager.clear_overrides() from pathlib import Path self.state_manager.set_last_workflow_file(Path(path).name) def get_workflow_template(self) -> Optional[dict[str, Any]]: """Return the current workflow template (or None).""" return self.workflow_manager.get_workflow_template() # ------------------------------------------------------------------ # State management convenience wrappers # ------------------------------------------------------------------ def get_workflow_current_changes(self) -> dict[str, Any]: """Return all current overrides (backward-compat).""" return self.state_manager.get_changes() def set_workflow_current_changes(self, changes: dict[str, Any]) -> None: """Merge override changes (backward-compat).""" self.state_manager.set_changes(changes, merge=True) def set_workflow_current_prompt(self, prompt: str) -> None: self.state_manager.set_prompt(prompt) def set_workflow_current_negative_prompt(self, negative_prompt: str) -> None: self.state_manager.set_negative_prompt(negative_prompt) def set_workflow_current_input_image(self, input_image: str) -> None: self.state_manager.set_input_image(input_image) def get_current_workflow_prompt(self) -> Optional[str]: return self.state_manager.get_prompt() def get_current_workflow_negative_prompt(self) -> Optional[str]: return self.state_manager.get_negative_prompt() def get_current_workflow_input_image(self) -> Optional[str]: return self.state_manager.get_input_image() # ------------------------------------------------------------------ # Image upload # ------------------------------------------------------------------ async def upload_image( self, data: bytes, filename: str, *, image_type: str = "input", overwrite: bool = False, ) -> dict[str, Any]: """Upload an image to ComfyUI via the /upload/image endpoint.""" url = f"{self.protocol}://{self.server_address}/upload/image" form = aiohttp.FormData() form.add_field("image", data, filename=filename, content_type="application/octet-stream") form.add_field("type", image_type) form.add_field("overwrite", str(overwrite).lower()) async with self.session.post(url, data=form) as resp: resp.raise_for_status() try: return await resp.json() except aiohttp.ContentTypeError: return {"status": await resp.text()} # ------------------------------------------------------------------ # History # ------------------------------------------------------------------ def get_history(self) -> List[dict]: """Return a list of recently generated prompt records (from DB).""" try: from generation_db import get_history as db_get_history return db_get_history(limit=self._history.maxlen or 50) except Exception: return list(self._history) async def fetch_history_images(self, prompt_id: str) -> List[bytes]: """Re-download images for a previously generated prompt.""" history = await self._get_history(prompt_id) images: List[bytes] = [] for node_output in history.get(prompt_id, {}).get("outputs", {}).values(): for image_info in node_output.get("images", []): data = await self._download_image( image_info["filename"], image_info["subfolder"], image_info["type"], ) images.append(data) return images # ------------------------------------------------------------------ # Server info / queue # ------------------------------------------------------------------ async def get_system_stats(self) -> Optional[dict[str, Any]]: """Fetch ComfyUI system stats (/system_stats).""" try: url = f"{self.protocol}://{self.server_address}/system_stats" async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp: resp.raise_for_status() return await resp.json() except Exception as exc: self.logger.warning("Failed to fetch system stats: %s", exc) return None async def get_comfy_queue(self) -> Optional[dict[str, Any]]: """Fetch the ComfyUI queue (/queue).""" try: url = f"{self.protocol}://{self.server_address}/queue" async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp: resp.raise_for_status() return await resp.json() except Exception as exc: self.logger.warning("Failed to fetch comfy queue: %s", exc) return None async def get_queue_depth(self) -> int: """Return the total number of pending + running jobs in ComfyUI.""" q = await self.get_comfy_queue() if q: return len(q.get("queue_running", [])) + len(q.get("queue_pending", [])) return 0 async def clear_queue(self) -> bool: """Clear all pending jobs from the ComfyUI queue.""" try: url = f"{self.protocol}://{self.server_address}/queue" async with self.session.post( url, json={"clear": True}, headers={"Content-Type": "application/json"}, timeout=aiohttp.ClientTimeout(total=5), ) as resp: return resp.status in (200, 204) except Exception as exc: self.logger.warning("Failed to clear comfy queue: %s", exc) return False async def check_connection(self) -> bool: """Return True if the ComfyUI server is reachable.""" try: url = f"{self.protocol}://{self.server_address}/system_stats" async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp: return resp.status == 200 except Exception: return False async def get_models(self, model_type: str = "checkpoints") -> List[str]: """ Fetch available model names from ComfyUI. Parameters ---------- model_type : str One of ``"checkpoints"``, ``"loras"``, etc. """ try: url = f"{self.protocol}://{self.server_address}/object_info" async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp: resp.raise_for_status() info = await resp.json() if model_type == "checkpoints": node = info.get("CheckpointLoaderSimple", {}) return node.get("input", {}).get("required", {}).get("ckpt_name", [None])[0] or [] elif model_type == "loras": node = info.get("LoraLoader", {}) return node.get("input", {}).get("required", {}).get("lora_name", [None])[0] or [] return [] except Exception as exc: self.logger.warning("Failed to fetch models (%s): %s", model_type, exc) return [] # ------------------------------------------------------------------ # Lifecycle # ------------------------------------------------------------------ async def close(self) -> None: """Close the underlying aiohttp session.""" if self._session and not self._session.closed: await self._session.close() async def __aenter__(self) -> "ComfyClient": return self async def __aexit__(self, exc_type, exc, tb) -> None: await self.close()