Initial commit — ComfyUI Discord bot + web UI

Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI
(React/TS/Vite/Tailwind), ComfyUI integration, generation history DB,
preset manager, workflow inspector, and all supporting modules.

Excluded from tracking: .env, invite_tokens.json, *.db (SQLite),
current-workflow-changes.json, user_settings/, presets/, logs/,
web-static/ (build output), frontend/node_modules/.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Khoa (Revenovich) Tran Gia
2026-03-02 09:55:48 +07:00
commit 1ed3c9ec4b
82 changed files with 20693 additions and 0 deletions

604
comfy_client.py Normal file
View File

@@ -0,0 +1,604 @@
"""
comfy_client.py
================
Asynchronous client for the ComfyUI API.
Wraps ComfyUI's REST and WebSocket endpoints. Workflow template injection
is now handled by :class:`~workflow_inspector.WorkflowInspector`, so this
class only needs to:
1. Accept a workflow template (delegated to WorkflowManager).
2. Accept runtime overrides (delegated to WorkflowStateManager).
3. Build the final workflow via inspector.inject_overrides().
4. Queue it to ComfyUI, wait for completion via WebSocket, fetch outputs.
A ``{prompt_id: callback}`` map is maintained for future WebSocket
broadcasting (web UI phase). Discord commands still use the synchronous
await-and-return model.
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from typing import Any, Callable, Dict, List, Optional
import aiohttp
import websockets
from workflow_inspector import WorkflowInspector
class ComfyClient:
    """
    Asynchronous ComfyUI client.

    Parameters
    ----------
    server_address : str
        ``hostname:port`` of the ComfyUI server.
    workflow_manager : WorkflowManager
        Template storage (injected).
    state_manager : WorkflowStateManager
        Runtime overrides (injected).
    logger : Optional[logging.Logger]
        Logger for debug/info messages.
    history_limit : int
        Max recent generations to keep in the in-memory deque.
    output_path : Optional[str]
        Local ComfyUI output directory. Needed to read generated video
        files back from disk for DB persistence (videos are referenced
        by filename/subfolder, not downloaded over HTTP).
    """

    def __init__(
        self,
        server_address: str,
        workflow_manager,
        state_manager,
        logger: Optional[logging.Logger] = None,
        *,
        history_limit: int = 10,
        output_path: Optional[str] = None,
    ) -> None:
        self.server_address = server_address.strip().rstrip("/")
        self.client_id = str(uuid.uuid4())
        self._session: Optional[aiohttp.ClientSession] = None
        self.protocol = "http"
        self.ws_protocol = "ws"
        self.workflow_manager = workflow_manager
        self.state_manager = state_manager
        self.inspector = WorkflowInspector()
        self.output_path = output_path
        # prompt_id → progress callback for web-UI broadcast (Phase 4)
        self._pending_callbacks: Dict[str, Callable] = {}
        from collections import deque
        self._history = deque(maxlen=history_limit)
        self.last_prompt_id: Optional[str] = None
        self.last_seed: Optional[int] = None
        self.total_generated: int = 0
        self.logger = logger if logger else logging.getLogger(__name__)

    # ------------------------------------------------------------------
    # Session
    # ------------------------------------------------------------------
    @property
    def session(self) -> aiohttp.ClientSession:
        """Lazily create and return an aiohttp session."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    # ------------------------------------------------------------------
    # Low-level REST helpers
    # ------------------------------------------------------------------
    async def _queue_prompt(
        self,
        prompt: dict[str, Any],
        prompt_id: str,
        ws_client_id: str | None = None,
    ) -> dict[str, Any]:
        """Submit a workflow to the ComfyUI queue."""
        payload = {
            "prompt": prompt,
            "client_id": ws_client_id if ws_client_id is not None else self.client_id,
            "prompt_id": prompt_id,
        }
        url = f"{self.protocol}://{self.server_address}/prompt"
        async with self.session.post(url, json=payload,
                                     headers={"Content-Type": "application/json"}) as resp:
            resp.raise_for_status()
            return await resp.json()

    async def _wait_for_execution(
        self,
        prompt_id: str,
        on_progress: Optional[Callable[[str, str], None]] = None,
        ws_client_id: str | None = None,
    ) -> None:
        """
        Wait for a queued prompt to finish executing via WebSocket.

        Parameters
        ----------
        prompt_id : str
            The prompt to wait for.
        on_progress : Optional[Callable[[str, str], None]]
            Called with ``(node_id, prompt_id)`` for each ``node_executing``
            event. Pass ``None`` for Discord commands (no web broadcast
            needed).
        ws_client_id : str | None
            Client id used for the WebSocket URL; defaults to this
            client's own id.

        Raises
        ------
        RuntimeError
            If ComfyUI reports an ``execution_error`` for *prompt_id*.
        """
        client_id = ws_client_id if ws_client_id is not None else self.client_id
        ws_url = (
            f"{self.ws_protocol}://{self.server_address}/ws"
            f"?clientId={client_id}"
        )
        async with websockets.connect(ws_url) as ws:
            try:
                while True:
                    out = await ws.recv()
                    # Binary frames (preview images) are ignored here.
                    if not isinstance(out, str):
                        continue
                    message = json.loads(out)
                    mtype = message.get("type")
                    if mtype == "executing":
                        data = message["data"]
                        node = data.get("node")
                        if node:
                            self.logger.debug("Executing node: %s", node)
                            if on_progress and data.get("prompt_id") == prompt_id:
                                try:
                                    on_progress(node, prompt_id)
                                except Exception:
                                    # Progress callbacks must never break the wait loop.
                                    pass
                        # A null node for our prompt_id signals completion.
                        # Use the already-fetched ``node`` (safe .get) instead of
                        # re-indexing data["node"], which could raise KeyError on
                        # a malformed message.
                        if node is None and data.get("prompt_id") == prompt_id:
                            self.logger.info("Execution complete for prompt %s", prompt_id)
                            break
                    elif mtype == "execution_success":
                        if message.get("data", {}).get("prompt_id") == prompt_id:
                            self.logger.info("execution_success for prompt %s", prompt_id)
                            break
                    elif mtype == "execution_error":
                        if message.get("data", {}).get("prompt_id") == prompt_id:
                            error = message.get("data", {}).get("exception_message", "unknown error")
                            raise RuntimeError(f"ComfyUI execution error: {error}")
            except Exception as exc:
                self.logger.error("Error during execution wait: %s", exc)
                raise

    async def _get_history(self, prompt_id: str) -> dict[str, Any]:
        """Retrieve execution history for a given prompt id."""
        url = f"{self.protocol}://{self.server_address}/history/{prompt_id}"
        async with self.session.get(url) as resp:
            resp.raise_for_status()
            return await resp.json()

    async def _download_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
        """Download an image from ComfyUI and return raw bytes."""
        url = f"{self.protocol}://{self.server_address}/view"
        params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        async with self.session.get(url, params=params) as resp:
            resp.raise_for_status()
            return await resp.read()

    # ------------------------------------------------------------------
    # Core generation pipeline
    # ------------------------------------------------------------------
    async def _general_generate(
        self,
        workflow: dict[str, Any],
        prompt_id: str,
        on_progress: Optional[Callable[[str, str], None]] = None,
    ) -> tuple[List[bytes], List[dict[str, Any]]]:
        """
        Queue a workflow, wait for it to execute, then collect outputs.

        Returns
        -------
        tuple[List[bytes], List[dict]]
            ``(images, videos)`` — images as raw bytes, videos as info dicts.
            Both empty on execution failure or missing history (errors are
            logged, not raised — callers treat this as "no output").
        """
        # Fresh client id per generation so concurrent waits don't share a socket.
        ws_client_id = str(uuid.uuid4())
        await self._queue_prompt(workflow, prompt_id, ws_client_id)
        try:
            await self._wait_for_execution(prompt_id, on_progress=on_progress, ws_client_id=ws_client_id)
        except Exception:
            self.logger.error("Execution failed for prompt %s", prompt_id)
            return [], []
        history = await self._get_history(prompt_id)
        if not history:
            self.logger.warning("No history for prompt %s", prompt_id)
            return [], []
        images: List[bytes] = []
        videos: List[dict[str, Any]] = []
        for node_output in history.get(prompt_id, {}).get("outputs", {}).values():
            for image_info in node_output.get("images", []):
                name = image_info["filename"]
                # ComfyUI lists video outputs under "images" too; route by extension.
                if name.rsplit(".", 1)[-1].lower() in {"mp4", "webm", "avi"}:
                    videos.append({
                        "video_name": name,
                        "video_subfolder": image_info.get("subfolder", ""),
                        "video_type": image_info.get("type", "output"),
                    })
                else:
                    data = await self._download_image(
                        name, image_info["subfolder"], image_info["type"]
                    )
                    images.append(data)
        return images, videos

    # ------------------------------------------------------------------
    # DB persistence helper
    # ------------------------------------------------------------------
    def _record_to_db(
        self,
        prompt_id: str,
        source: str,
        user_label: Optional[str],
        overrides: Dict[str, Any],
        seed: Optional[int],
        images: List[bytes],
        videos: List[Dict[str, Any]],
    ) -> None:
        """Persist generation metadata and file blobs to SQLite. Never raises."""
        try:
            import generation_db
            from pathlib import Path as _Path
            gen_id = generation_db.record_generation(
                prompt_id, source, user_label, overrides, seed
            )
            for i, img_data in enumerate(images):
                generation_db.record_file(gen_id, f"image_{i:04d}.png", img_data)
            # Videos are read back from the local output directory (if configured)
            # since they are not downloaded over HTTP.
            if videos and self.output_path:
                for vid in videos:
                    vname = vid.get("video_name", "")
                    vsub = vid.get("video_subfolder", "")
                    vpath = (
                        _Path(self.output_path) / vsub / vname
                        if vsub
                        else _Path(self.output_path) / vname
                    )
                    try:
                        generation_db.record_file(gen_id, vname, vpath.read_bytes())
                    except OSError as exc:
                        self.logger.warning(
                            "Could not read video for DB storage: %s: %s", vpath, exc
                        )
        except Exception as exc:
            # Best-effort persistence: DB problems must never fail a generation.
            self.logger.warning("Failed to record generation to DB: %s", exc)

    # ------------------------------------------------------------------
    # Public generation API
    # ------------------------------------------------------------------
    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        on_progress: Optional[Callable[[str, str], None]] = None,
        *,
        source: str = "discord",
        user_label: Optional[str] = None,
    ) -> tuple[List[bytes], str]:
        """
        Generate images using the current workflow template with a text prompt.

        Injects *prompt* (and optionally *negative_prompt*) via the inspector,
        plus any currently pinned seed from the state manager. All other
        overrides in the state manager are **not** applied here — use
        :meth:`generate_image_with_workflow` for the full override set.

        Parameters
        ----------
        prompt : str
            Positive prompt text.
        negative_prompt : Optional[str]
            Negative prompt text (optional).
        on_progress : Optional[Callable]
            Called with ``(node_id, prompt_id)`` for each executing node.

        Returns
        -------
        tuple[List[bytes], str]
            ``(images, prompt_id)``
        """
        template = self.workflow_manager.get_workflow_template()
        if not template:
            self.logger.warning("No workflow template set; cannot generate.")
            return [], ""
        overrides: Dict[str, Any] = {"prompt": prompt}
        if negative_prompt is not None:
            overrides["negative_prompt"] = negative_prompt
        # Respect pinned seed from state manager
        seed_pin = self.state_manager.get_seed()
        if seed_pin is not None:
            overrides["seed"] = seed_pin
        workflow, applied = self.inspector.inject_overrides(template, overrides)
        seed_used = applied.get("seed")
        self.last_seed = seed_used
        prompt_id = str(uuid.uuid4())
        images, _videos = await self._general_generate(workflow, prompt_id, on_progress)
        self.last_prompt_id = prompt_id
        self.total_generated += 1
        self._history.append({
            "prompt_id": prompt_id,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "seed": seed_used,
        })
        # Bug fix: previously passed [] here, so any videos produced by the
        # workflow were silently dropped from the history DB on this path.
        self._record_to_db(
            prompt_id, source, user_label,
            {"prompt": prompt, "negative_prompt": negative_prompt},
            seed_used, images, _videos,
        )
        return images, prompt_id

    async def generate_image_with_workflow(
        self,
        on_progress: Optional[Callable[[str, str], None]] = None,
        *,
        source: str = "discord",
        user_label: Optional[str] = None,
    ) -> tuple[List[bytes], List[dict[str, Any]], str]:
        """
        Generate images/videos from the current workflow applying ALL
        overrides stored in the state manager.

        Returns
        -------
        tuple[List[bytes], List[dict], str]
            ``(images, videos, prompt_id)``
        """
        template = self.workflow_manager.get_workflow_template()
        prompt_id = str(uuid.uuid4())
        if not template:
            self.logger.error("No workflow template set")
            return [], [], prompt_id
        overrides = self.state_manager.get_overrides()
        workflow, applied = self.inspector.inject_overrides(template, overrides)
        seed_used = applied.get("seed")
        self.last_seed = seed_used
        images, videos = await self._general_generate(workflow, prompt_id, on_progress)
        self.last_prompt_id = prompt_id
        self.total_generated += 1
        prompt_str = overrides.get("prompt") or ""
        neg_str = overrides.get("negative_prompt") or ""
        # Bug fix: truncated entries previously appended "" (a no-op — the
        # ellipsis was evidently lost), making truncated and full prompts
        # indistinguishable in the in-memory history.
        self._history.append({
            "prompt_id": prompt_id,
            "prompt": (prompt_str[:10] + "…") if len(prompt_str) > 10 else prompt_str or None,
            "negative_prompt": (neg_str[:10] + "…") if len(neg_str) > 10 else neg_str or None,
            "seed": seed_used,
        })
        self._record_to_db(prompt_id, source, user_label, overrides, seed_used, images, videos)
        return images, videos, prompt_id

    # ------------------------------------------------------------------
    # Workflow template management
    # ------------------------------------------------------------------
    def set_workflow(self, workflow: dict[str, Any]) -> None:
        """Set the workflow template and clear all state overrides."""
        self.workflow_manager.set_workflow_template(workflow)
        self.state_manager.clear_overrides()

    def load_workflow_from_file(self, path: str) -> None:
        """
        Load a workflow template from a JSON file.

        Also clears state overrides and records the filename in the state
        manager for auto-load on restart.
        """
        import json as _json
        with open(path, "r", encoding="utf-8") as f:
            workflow = _json.load(f)
        self.workflow_manager.set_workflow_template(workflow)
        self.state_manager.clear_overrides()
        from pathlib import Path
        self.state_manager.set_last_workflow_file(Path(path).name)

    def get_workflow_template(self) -> Optional[dict[str, Any]]:
        """Return the current workflow template (or None)."""
        return self.workflow_manager.get_workflow_template()

    # ------------------------------------------------------------------
    # State management convenience wrappers
    # ------------------------------------------------------------------
    def get_workflow_current_changes(self) -> dict[str, Any]:
        """Return all current overrides (backward-compat)."""
        return self.state_manager.get_changes()

    def set_workflow_current_changes(self, changes: dict[str, Any]) -> None:
        """Merge override changes (backward-compat)."""
        self.state_manager.set_changes(changes, merge=True)

    def set_workflow_current_prompt(self, prompt: str) -> None:
        """Set the current positive prompt override."""
        self.state_manager.set_prompt(prompt)

    def set_workflow_current_negative_prompt(self, negative_prompt: str) -> None:
        """Set the current negative prompt override."""
        self.state_manager.set_negative_prompt(negative_prompt)

    def set_workflow_current_input_image(self, input_image: str) -> None:
        """Set the current input-image override."""
        self.state_manager.set_input_image(input_image)

    def get_current_workflow_prompt(self) -> Optional[str]:
        """Return the current positive prompt override (or None)."""
        return self.state_manager.get_prompt()

    def get_current_workflow_negative_prompt(self) -> Optional[str]:
        """Return the current negative prompt override (or None)."""
        return self.state_manager.get_negative_prompt()

    def get_current_workflow_input_image(self) -> Optional[str]:
        """Return the current input-image override (or None)."""
        return self.state_manager.get_input_image()

    # ------------------------------------------------------------------
    # Image upload
    # ------------------------------------------------------------------
    async def upload_image(
        self,
        data: bytes,
        filename: str,
        *,
        image_type: str = "input",
        overwrite: bool = False,
    ) -> dict[str, Any]:
        """Upload an image to ComfyUI via the /upload/image endpoint."""
        url = f"{self.protocol}://{self.server_address}/upload/image"
        form = aiohttp.FormData()
        form.add_field("image", data, filename=filename,
                       content_type="application/octet-stream")
        form.add_field("type", image_type)
        form.add_field("overwrite", str(overwrite).lower())
        async with self.session.post(url, data=form) as resp:
            resp.raise_for_status()
            try:
                return await resp.json()
            except aiohttp.ContentTypeError:
                # Some ComfyUI builds reply with plain text; wrap it.
                return {"status": await resp.text()}

    # ------------------------------------------------------------------
    # History
    # ------------------------------------------------------------------
    def get_history(self) -> List[dict]:
        """Return a list of recently generated prompt records (from DB)."""
        try:
            from generation_db import get_history as db_get_history
            return db_get_history(limit=self._history.maxlen or 50)
        except Exception:
            # DB unavailable — fall back to the in-memory deque.
            return list(self._history)

    async def fetch_history_images(self, prompt_id: str) -> List[bytes]:
        """Re-download images for a previously generated prompt."""
        history = await self._get_history(prompt_id)
        images: List[bytes] = []
        for node_output in history.get(prompt_id, {}).get("outputs", {}).values():
            for image_info in node_output.get("images", []):
                data = await self._download_image(
                    image_info["filename"],
                    image_info["subfolder"],
                    image_info["type"],
                )
                images.append(data)
        return images

    # ------------------------------------------------------------------
    # Server info / queue
    # ------------------------------------------------------------------
    async def get_system_stats(self) -> Optional[dict[str, Any]]:
        """Fetch ComfyUI system stats (/system_stats). None on failure."""
        try:
            url = f"{self.protocol}://{self.server_address}/system_stats"
            async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
                resp.raise_for_status()
                return await resp.json()
        except Exception as exc:
            self.logger.warning("Failed to fetch system stats: %s", exc)
            return None

    async def get_comfy_queue(self) -> Optional[dict[str, Any]]:
        """Fetch the ComfyUI queue (/queue). None on failure."""
        try:
            url = f"{self.protocol}://{self.server_address}/queue"
            async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
                resp.raise_for_status()
                return await resp.json()
        except Exception as exc:
            self.logger.warning("Failed to fetch comfy queue: %s", exc)
            return None

    async def get_queue_depth(self) -> int:
        """Return the total number of pending + running jobs in ComfyUI."""
        q = await self.get_comfy_queue()
        if q:
            return len(q.get("queue_running", [])) + len(q.get("queue_pending", []))
        return 0

    async def clear_queue(self) -> bool:
        """Clear all pending jobs from the ComfyUI queue."""
        try:
            url = f"{self.protocol}://{self.server_address}/queue"
            async with self.session.post(
                url,
                json={"clear": True},
                headers={"Content-Type": "application/json"},
                timeout=aiohttp.ClientTimeout(total=5),
            ) as resp:
                return resp.status in (200, 204)
        except Exception as exc:
            self.logger.warning("Failed to clear comfy queue: %s", exc)
            return False

    async def check_connection(self) -> bool:
        """Return True if the ComfyUI server is reachable."""
        try:
            url = f"{self.protocol}://{self.server_address}/system_stats"
            async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
                return resp.status == 200
        except Exception:
            return False

    async def get_models(self, model_type: str = "checkpoints") -> List[str]:
        """
        Fetch available model names from ComfyUI.

        Parameters
        ----------
        model_type : str
            One of ``"checkpoints"``, ``"loras"``, etc. Unknown types
            (and any failure) return an empty list.
        """
        try:
            url = f"{self.protocol}://{self.server_address}/object_info"
            async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
                resp.raise_for_status()
                info = await resp.json()
                # object_info input specs are [choices, config]; [0] is the
                # list of valid names. The [None] default keeps [0] safe when
                # the key is missing, and ``or []`` normalizes None to [].
                if model_type == "checkpoints":
                    node = info.get("CheckpointLoaderSimple", {})
                    return node.get("input", {}).get("required", {}).get("ckpt_name", [None])[0] or []
                elif model_type == "loras":
                    node = info.get("LoraLoader", {})
                    return node.get("input", {}).get("required", {}).get("lora_name", [None])[0] or []
                return []
        except Exception as exc:
            self.logger.warning("Failed to fetch models (%s): %s", model_type, exc)
            return []

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    async def close(self) -> None:
        """Close the underlying aiohttp session."""
        if self._session and not self._session.closed:
            await self._session.close()

    async def __aenter__(self) -> "ComfyClient":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()