Initial commit — ComfyUI Discord bot + web UI
Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI (React/TS/Vite/Tailwind), ComfyUI integration, generation history DB, preset manager, workflow inspector, and all supporting modules. Excluded from tracking: .env, invite_tokens.json, *.db (SQLite), current-workflow-changes.json, user_settings/, presets/, logs/, web-static/ (build output), frontend/node_modules/. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
604
comfy_client.py
Normal file
604
comfy_client.py
Normal file
@@ -0,0 +1,604 @@
|
||||
"""
|
||||
comfy_client.py
|
||||
================
|
||||
|
||||
Asynchronous client for the ComfyUI API.
|
||||
|
||||
Wraps ComfyUI's REST and WebSocket endpoints. Workflow template injection
|
||||
is now handled by :class:`~workflow_inspector.WorkflowInspector`, so this
|
||||
class only needs to:
|
||||
|
||||
1. Accept a workflow template (delegated to WorkflowManager).
|
||||
2. Accept runtime overrides (delegated to WorkflowStateManager).
|
||||
3. Build the final workflow via inspector.inject_overrides().
|
||||
4. Queue it to ComfyUI, wait for completion via WebSocket, fetch outputs.
|
||||
|
||||
A ``{prompt_id: callback}`` map is maintained for future WebSocket
|
||||
broadcasting (web UI phase). Discord commands still use the synchronous
|
||||
await-and-return model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations

import asyncio
import json
import logging
import uuid
from collections import deque
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import aiohttp
import websockets

from workflow_inspector import WorkflowInspector
|
||||
|
||||
|
||||
class ComfyClient:
|
||||
"""
|
||||
Asynchronous ComfyUI client.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
server_address : str
|
||||
``hostname:port`` of the ComfyUI server.
|
||||
workflow_manager : WorkflowManager
|
||||
Template storage (injected).
|
||||
state_manager : WorkflowStateManager
|
||||
Runtime overrides (injected).
|
||||
logger : Optional[logging.Logger]
|
||||
Logger for debug/info messages.
|
||||
history_limit : int
|
||||
Max recent generations to keep in the in-memory deque.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_address: str,
|
||||
workflow_manager,
|
||||
state_manager,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
*,
|
||||
history_limit: int = 10,
|
||||
output_path: Optional[str] = None,
|
||||
) -> None:
|
||||
self.server_address = server_address.strip().rstrip("/")
|
||||
self.client_id = str(uuid.uuid4())
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
self.protocol = "http"
|
||||
self.ws_protocol = "ws"
|
||||
|
||||
self.workflow_manager = workflow_manager
|
||||
self.state_manager = state_manager
|
||||
self.inspector = WorkflowInspector()
|
||||
self.output_path = output_path
|
||||
|
||||
# prompt_id → asyncio.Future for web-UI broadcast (Phase 4)
|
||||
self._pending_callbacks: Dict[str, Callable] = {}
|
||||
|
||||
from collections import deque
|
||||
self._history = deque(maxlen=history_limit)
|
||||
|
||||
self.last_prompt_id: Optional[str] = None
|
||||
self.last_seed: Optional[int] = None
|
||||
self.total_generated: int = 0
|
||||
|
||||
self.logger = logger if logger else logging.getLogger(__name__)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Session
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
"""Lazily create and return an aiohttp session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Low-level REST helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _queue_prompt(
|
||||
self,
|
||||
prompt: dict[str, Any],
|
||||
prompt_id: str,
|
||||
ws_client_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Submit a workflow to the ComfyUI queue."""
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"client_id": ws_client_id if ws_client_id is not None else self.client_id,
|
||||
"prompt_id": prompt_id,
|
||||
}
|
||||
url = f"{self.protocol}://{self.server_address}/prompt"
|
||||
async with self.session.post(url, json=payload,
|
||||
headers={"Content-Type": "application/json"}) as resp:
|
||||
resp.raise_for_status()
|
||||
return await resp.json()
|
||||
|
||||
async def _wait_for_execution(
|
||||
self,
|
||||
prompt_id: str,
|
||||
on_progress: Optional[Callable[[str, str], None]] = None,
|
||||
ws_client_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Wait for a queued prompt to finish executing via WebSocket.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prompt_id : str
|
||||
The prompt to wait for.
|
||||
on_progress : Optional[Callable[[str, str], None]]
|
||||
Called with ``(node_id, prompt_id)`` for each ``node_executing``
|
||||
event. Pass ``None`` for Discord commands (no web broadcast
|
||||
needed).
|
||||
"""
|
||||
client_id = ws_client_id if ws_client_id is not None else self.client_id
|
||||
ws_url = (
|
||||
f"{self.ws_protocol}://{self.server_address}/ws"
|
||||
f"?clientId={client_id}"
|
||||
)
|
||||
async with websockets.connect(ws_url) as ws:
|
||||
try:
|
||||
while True:
|
||||
out = await ws.recv()
|
||||
if not isinstance(out, str):
|
||||
continue
|
||||
message = json.loads(out)
|
||||
mtype = message.get("type")
|
||||
|
||||
if mtype == "executing":
|
||||
data = message["data"]
|
||||
node = data.get("node")
|
||||
if node:
|
||||
self.logger.debug("Executing node: %s", node)
|
||||
if on_progress and data.get("prompt_id") == prompt_id:
|
||||
try:
|
||||
on_progress(node, prompt_id)
|
||||
except Exception:
|
||||
pass
|
||||
if data["node"] is None and data.get("prompt_id") == prompt_id:
|
||||
self.logger.info("Execution complete for prompt %s", prompt_id)
|
||||
break
|
||||
|
||||
elif mtype == "execution_success":
|
||||
if message.get("data", {}).get("prompt_id") == prompt_id:
|
||||
self.logger.info("execution_success for prompt %s", prompt_id)
|
||||
break
|
||||
|
||||
elif mtype == "execution_error":
|
||||
if message.get("data", {}).get("prompt_id") == prompt_id:
|
||||
error = message.get("data", {}).get("exception_message", "unknown error")
|
||||
raise RuntimeError(f"ComfyUI execution error: {error}")
|
||||
|
||||
except Exception as exc:
|
||||
self.logger.error("Error during execution wait: %s", exc)
|
||||
raise
|
||||
|
||||
async def _get_history(self, prompt_id: str) -> dict[str, Any]:
|
||||
"""Retrieve execution history for a given prompt id."""
|
||||
url = f"{self.protocol}://{self.server_address}/history/{prompt_id}"
|
||||
async with self.session.get(url) as resp:
|
||||
resp.raise_for_status()
|
||||
return await resp.json()
|
||||
|
||||
async def _download_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
|
||||
"""Download an image from ComfyUI and return raw bytes."""
|
||||
url = f"{self.protocol}://{self.server_address}/view"
|
||||
params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||
async with self.session.get(url, params=params) as resp:
|
||||
resp.raise_for_status()
|
||||
return await resp.read()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core generation pipeline
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    async def _general_generate(
        self,
        workflow: dict[str, Any],
        prompt_id: str,
        on_progress: Optional[Callable[[str, str], None]] = None,
    ) -> tuple[List[bytes], List[dict[str, Any]]]:
        """
        Queue a workflow, wait for it to execute, then collect outputs.

        Pipeline: ``_queue_prompt`` → ``_wait_for_execution`` →
        ``_get_history`` → ``_download_image`` per image output.

        Parameters
        ----------
        workflow : dict[str, Any]
            Fully-resolved ComfyUI workflow (overrides already injected).
        prompt_id : str
            Client-generated id correlating the queue, WS and history calls.
        on_progress : Optional[Callable[[str, str], None]]
            Forwarded to :meth:`_wait_for_execution`.

        Returns
        -------
        tuple[List[bytes], List[dict]]
            ``(images, videos)`` — images as raw bytes, videos as info dicts.
            Both lists are empty on execution failure or missing history.
        """
        # Fresh WebSocket client id per generation; NOTE(review): presumably
        # so concurrent generations each get their own WS stream — confirm.
        ws_client_id = str(uuid.uuid4())
        await self._queue_prompt(workflow, prompt_id, ws_client_id)
        try:
            await self._wait_for_execution(prompt_id, on_progress=on_progress, ws_client_id=ws_client_id)
        except Exception:
            # Best-effort contract: callers get empty results, not an error.
            self.logger.error("Execution failed for prompt %s", prompt_id)
            return [], []

        history = await self._get_history(prompt_id)
        if not history:
            self.logger.warning("No history for prompt %s", prompt_id)
            return [], []

        images: List[bytes] = []
        videos: List[dict[str, Any]] = []

        for node_output in history.get(prompt_id, {}).get("outputs", {}).values():
            for image_info in node_output.get("images", []):
                name = image_info["filename"]
                # Video outputs can appear under "images" too; split on the
                # file extension.  No dot → treated as an image.
                if name.rsplit(".", 1)[-1].lower() in {"mp4", "webm", "avi"}:
                    # Videos are not downloaded here — only referenced, so the
                    # caller can stream/read them from disk.
                    videos.append({
                        "video_name": name,
                        "video_subfolder": image_info.get("subfolder", ""),
                        "video_type": image_info.get("type", "output"),
                    })
                else:
                    data = await self._download_image(
                        name, image_info["subfolder"], image_info["type"]
                    )
                    images.append(data)

        return images, videos
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# DB persistence helper
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _record_to_db(
|
||||
self,
|
||||
prompt_id: str,
|
||||
source: str,
|
||||
user_label: Optional[str],
|
||||
overrides: Dict[str, Any],
|
||||
seed: Optional[int],
|
||||
images: List[bytes],
|
||||
videos: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""Persist generation metadata and file blobs to SQLite. Never raises."""
|
||||
try:
|
||||
import generation_db
|
||||
from pathlib import Path as _Path
|
||||
gen_id = generation_db.record_generation(
|
||||
prompt_id, source, user_label, overrides, seed
|
||||
)
|
||||
for i, img_data in enumerate(images):
|
||||
generation_db.record_file(gen_id, f"image_{i:04d}.png", img_data)
|
||||
if videos and self.output_path:
|
||||
for vid in videos:
|
||||
vname = vid.get("video_name", "")
|
||||
vsub = vid.get("video_subfolder", "")
|
||||
vpath = (
|
||||
_Path(self.output_path) / vsub / vname
|
||||
if vsub
|
||||
else _Path(self.output_path) / vname
|
||||
)
|
||||
try:
|
||||
generation_db.record_file(gen_id, vname, vpath.read_bytes())
|
||||
except OSError as exc:
|
||||
self.logger.warning(
|
||||
"Could not read video for DB storage: %s: %s", vpath, exc
|
||||
)
|
||||
except Exception as exc:
|
||||
self.logger.warning("Failed to record generation to DB: %s", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public generation API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        on_progress: Optional[Callable[[str, str], None]] = None,
        *,
        source: str = "discord",
        user_label: Optional[str] = None,
    ) -> tuple[List[bytes], str]:
        """
        Generate images using the current workflow template with a text prompt.

        Injects *prompt* (and optionally *negative_prompt*) via the inspector,
        plus any currently pinned seed from the state manager. All other
        overrides in the state manager are **not** applied here — use
        :meth:`generate_image_with_workflow` for the full override set.

        Parameters
        ----------
        prompt : str
            Positive prompt text.
        negative_prompt : Optional[str]
            Negative prompt text (optional).
        on_progress : Optional[Callable]
            Called with ``(node_id, prompt_id)`` for each executing node.
        source : str
            Origin tag recorded in the DB (e.g. ``"discord"``).
        user_label : Optional[str]
            Free-form user identifier recorded in the DB.

        Returns
        -------
        tuple[List[bytes], str]
            ``(images, prompt_id)``; ``([], "")`` when no template is set.
        """
        template = self.workflow_manager.get_workflow_template()
        if not template:
            self.logger.warning("No workflow template set; cannot generate.")
            return [], ""

        overrides: Dict[str, Any] = {"prompt": prompt}
        if negative_prompt is not None:
            overrides["negative_prompt"] = negative_prompt
        # Respect pinned seed from state manager
        seed_pin = self.state_manager.get_seed()
        if seed_pin is not None:
            overrides["seed"] = seed_pin

        # The inspector reports back what it actually applied, including the
        # effective seed (pinned or otherwise).
        workflow, applied = self.inspector.inject_overrides(template, overrides)
        seed_used = applied.get("seed")
        self.last_seed = seed_used

        prompt_id = str(uuid.uuid4())
        # Videos are ignored for the plain-prompt path.
        images, _videos = await self._general_generate(workflow, prompt_id, on_progress)

        # Bookkeeping: stats, bounded in-memory history, then SQLite.
        self.last_prompt_id = prompt_id
        self.total_generated += 1
        self._history.append({
            "prompt_id": prompt_id,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "seed": seed_used,
        })
        self._record_to_db(
            prompt_id, source, user_label,
            {"prompt": prompt, "negative_prompt": negative_prompt},
            seed_used, images, [],
        )
        return images, prompt_id
|
||||
|
||||
    async def generate_image_with_workflow(
        self,
        on_progress: Optional[Callable[[str, str], None]] = None,
        *,
        source: str = "discord",
        user_label: Optional[str] = None,
    ) -> tuple[List[bytes], List[dict[str, Any]], str]:
        """
        Generate images/videos from the current workflow applying ALL
        overrides stored in the state manager.

        Parameters
        ----------
        on_progress : Optional[Callable]
            Called with ``(node_id, prompt_id)`` for each executing node.
        source : str
            Origin tag recorded in the DB (e.g. ``"discord"``).
        user_label : Optional[str]
            Free-form user identifier recorded in the DB.

        Returns
        -------
        tuple[List[bytes], List[dict], str]
            ``(images, videos, prompt_id)``; empty lists when no template
            is set or execution fails.
        """
        template = self.workflow_manager.get_workflow_template()
        prompt_id = str(uuid.uuid4())
        if not template:
            self.logger.error("No workflow template set")
            return [], [], prompt_id

        overrides = self.state_manager.get_overrides()
        workflow, applied = self.inspector.inject_overrides(template, overrides)
        seed_used = applied.get("seed")
        self.last_seed = seed_used

        images, videos = await self._general_generate(workflow, prompt_id, on_progress)

        # Bookkeeping mirrors generate_image(); prompts are truncated to 10
        # chars for the in-memory history, while the DB gets full overrides.
        self.last_prompt_id = prompt_id
        self.total_generated += 1
        prompt_str = overrides.get("prompt") or ""
        neg_str = overrides.get("negative_prompt") or ""
        self._history.append({
            "prompt_id": prompt_id,
            "prompt": (prompt_str[:10] + "…") if len(prompt_str) > 10 else prompt_str or None,
            "negative_prompt": (neg_str[:10] + "…") if len(neg_str) > 10 else neg_str or None,
            "seed": seed_used,
        })
        self._record_to_db(prompt_id, source, user_label, overrides, seed_used, images, videos)
        return images, videos, prompt_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Workflow template management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    def set_workflow(self, workflow: dict[str, Any]) -> None:
        """Set the workflow template and clear all state overrides."""
        # A new template invalidates overrides aimed at the previous one.
        self.workflow_manager.set_workflow_template(workflow)
        self.state_manager.clear_overrides()
|
||||
|
||||
def load_workflow_from_file(self, path: str) -> None:
|
||||
"""
|
||||
Load a workflow template from a JSON file.
|
||||
|
||||
Also clears state overrides and records the filename in the state
|
||||
manager for auto-load on restart.
|
||||
"""
|
||||
import json as _json
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
workflow = _json.load(f)
|
||||
self.workflow_manager.set_workflow_template(workflow)
|
||||
self.state_manager.clear_overrides()
|
||||
from pathlib import Path
|
||||
self.state_manager.set_last_workflow_file(Path(path).name)
|
||||
|
||||
    def get_workflow_template(self) -> Optional[dict[str, Any]]:
        """Return the current workflow template (or None)."""
        return self.workflow_manager.get_workflow_template()

    # ------------------------------------------------------------------
    # State management convenience wrappers
    # Thin delegation layer kept for backward compatibility with callers
    # that predate the WorkflowStateManager split.
    # ------------------------------------------------------------------

    def get_workflow_current_changes(self) -> dict[str, Any]:
        """Return all current overrides (backward-compat)."""
        return self.state_manager.get_changes()

    def set_workflow_current_changes(self, changes: dict[str, Any]) -> None:
        """Merge override changes (backward-compat)."""
        self.state_manager.set_changes(changes, merge=True)

    def set_workflow_current_prompt(self, prompt: str) -> None:
        """Set the positive prompt override (delegates to the state manager)."""
        self.state_manager.set_prompt(prompt)

    def set_workflow_current_negative_prompt(self, negative_prompt: str) -> None:
        """Set the negative prompt override (delegates to the state manager)."""
        self.state_manager.set_negative_prompt(negative_prompt)

    def set_workflow_current_input_image(self, input_image: str) -> None:
        """Set the input-image override (delegates to the state manager)."""
        self.state_manager.set_input_image(input_image)

    def get_current_workflow_prompt(self) -> Optional[str]:
        """Return the current positive prompt override, if any."""
        return self.state_manager.get_prompt()

    def get_current_workflow_negative_prompt(self) -> Optional[str]:
        """Return the current negative prompt override, if any."""
        return self.state_manager.get_negative_prompt()

    def get_current_workflow_input_image(self) -> Optional[str]:
        """Return the current input-image override, if any."""
        return self.state_manager.get_input_image()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Image upload
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def upload_image(
|
||||
self,
|
||||
data: bytes,
|
||||
filename: str,
|
||||
*,
|
||||
image_type: str = "input",
|
||||
overwrite: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Upload an image to ComfyUI via the /upload/image endpoint."""
|
||||
url = f"{self.protocol}://{self.server_address}/upload/image"
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("image", data, filename=filename,
|
||||
content_type="application/octet-stream")
|
||||
form.add_field("type", image_type)
|
||||
form.add_field("overwrite", str(overwrite).lower())
|
||||
async with self.session.post(url, data=form) as resp:
|
||||
resp.raise_for_status()
|
||||
try:
|
||||
return await resp.json()
|
||||
except aiohttp.ContentTypeError:
|
||||
return {"status": await resp.text()}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# History
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_history(self) -> List[dict]:
|
||||
"""Return a list of recently generated prompt records (from DB)."""
|
||||
try:
|
||||
from generation_db import get_history as db_get_history
|
||||
return db_get_history(limit=self._history.maxlen or 50)
|
||||
except Exception:
|
||||
return list(self._history)
|
||||
|
||||
async def fetch_history_images(self, prompt_id: str) -> List[bytes]:
|
||||
"""Re-download images for a previously generated prompt."""
|
||||
history = await self._get_history(prompt_id)
|
||||
images: List[bytes] = []
|
||||
for node_output in history.get(prompt_id, {}).get("outputs", {}).values():
|
||||
for image_info in node_output.get("images", []):
|
||||
data = await self._download_image(
|
||||
image_info["filename"],
|
||||
image_info["subfolder"],
|
||||
image_info["type"],
|
||||
)
|
||||
images.append(data)
|
||||
return images
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Server info / queue
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_system_stats(self) -> Optional[dict[str, Any]]:
|
||||
"""Fetch ComfyUI system stats (/system_stats)."""
|
||||
try:
|
||||
url = f"{self.protocol}://{self.server_address}/system_stats"
|
||||
async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
|
||||
resp.raise_for_status()
|
||||
return await resp.json()
|
||||
except Exception as exc:
|
||||
self.logger.warning("Failed to fetch system stats: %s", exc)
|
||||
return None
|
||||
|
||||
async def get_comfy_queue(self) -> Optional[dict[str, Any]]:
|
||||
"""Fetch the ComfyUI queue (/queue)."""
|
||||
try:
|
||||
url = f"{self.protocol}://{self.server_address}/queue"
|
||||
async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
|
||||
resp.raise_for_status()
|
||||
return await resp.json()
|
||||
except Exception as exc:
|
||||
self.logger.warning("Failed to fetch comfy queue: %s", exc)
|
||||
return None
|
||||
|
||||
async def get_queue_depth(self) -> int:
|
||||
"""Return the total number of pending + running jobs in ComfyUI."""
|
||||
q = await self.get_comfy_queue()
|
||||
if q:
|
||||
return len(q.get("queue_running", [])) + len(q.get("queue_pending", []))
|
||||
return 0
|
||||
|
||||
async def clear_queue(self) -> bool:
|
||||
"""Clear all pending jobs from the ComfyUI queue."""
|
||||
try:
|
||||
url = f"{self.protocol}://{self.server_address}/queue"
|
||||
async with self.session.post(
|
||||
url,
|
||||
json={"clear": True},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
) as resp:
|
||||
return resp.status in (200, 204)
|
||||
except Exception as exc:
|
||||
self.logger.warning("Failed to clear comfy queue: %s", exc)
|
||||
return False
|
||||
|
||||
async def check_connection(self) -> bool:
|
||||
"""Return True if the ComfyUI server is reachable."""
|
||||
try:
|
||||
url = f"{self.protocol}://{self.server_address}/system_stats"
|
||||
async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
|
||||
return resp.status == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_models(self, model_type: str = "checkpoints") -> List[str]:
|
||||
"""
|
||||
Fetch available model names from ComfyUI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_type : str
|
||||
One of ``"checkpoints"``, ``"loras"``, etc.
|
||||
"""
|
||||
try:
|
||||
url = f"{self.protocol}://{self.server_address}/object_info"
|
||||
async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
||||
resp.raise_for_status()
|
||||
info = await resp.json()
|
||||
if model_type == "checkpoints":
|
||||
node = info.get("CheckpointLoaderSimple", {})
|
||||
return node.get("input", {}).get("required", {}).get("ckpt_name", [None])[0] or []
|
||||
elif model_type == "loras":
|
||||
node = info.get("LoraLoader", {})
|
||||
return node.get("input", {}).get("required", {}).get("lora_name", [None])[0] or []
|
||||
return []
|
||||
except Exception as exc:
|
||||
self.logger.warning("Failed to fetch models (%s): %s", model_type, exc)
|
||||
return []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying aiohttp session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
    async def __aenter__(self) -> "ComfyClient":
        # Enable ``async with ComfyClient(...) as client:`` usage.
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        # Always release the aiohttp session on context exit.
        await self.close()
|
||||
Reference in New Issue
Block a user