Files
comfy-discord-web/workflow_inspector.py
Khoa (Revenovich) Tran Gia 1ed3c9ec4b Initial commit — ComfyUI Discord bot + web UI
Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI
(React/TS/Vite/Tailwind), ComfyUI integration, generation history DB,
preset manager, workflow inspector, and all supporting modules.

Excluded from tracking: .env, invite_tokens.json, *.db (SQLite),
current-workflow-changes.json, user_settings/, presets/, logs/,
web-static/ (build output), frontend/node_modules/.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 09:55:48 +07:00

398 lines
14 KiB
Python

"""
workflow_inspector.py
=====================
Dynamic workflow node inspection and injection for the Discord ComfyUI bot.
Replaces the hardcoded node-finding methods in WorkflowManager with a
general-purpose inspector that works with any ComfyUI workflow. The
inspector discovers injectable inputs at load time by walking the workflow
JSON, classifying each scalar input by class_type + input_name, and
assigning stable human-readable keys (e.g. ``"prompt"``, ``"seed"``,
``"input_image"``) that can be stored in WorkflowStateManager and used by
both Discord commands and the web UI.
Key-assignment rules
--------------------
- CLIPTextEncode / text → ``"prompt"`` / ``"negative_prompt"`` / ``"text_{node_id}"``
- LoadImage / image → ``"input_image"`` (first), ``{title_slug}`` (subsequent)
- KSampler* / seed → ``"seed"``
- KSampler / steps|cfg|sampler_name|scheduler|denoise → same name
- EmptyLatentImage / width|height → ``"width"`` / ``"height"``
- CheckpointLoaderSimple / ckpt_name → ``"checkpoint"``
- LoraLoader / lora_name → ``"lora_{node_id}"``
- Other scalars → ``"{class_slug}_{node_id}_{input_name}"``
"""
from __future__ import annotations
import copy
import random
import re
from dataclasses import dataclass
from typing import Any, Optional
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _slugify(text: str) -> str:
"""Convert arbitrary text to a lowercase underscore key."""
text = text.lower()
text = re.sub(r"[^a-z0-9]+", "_", text)
return text.strip("_") or "node"
def _numeric_sort_key(node_id: str):
"""Sort node IDs numerically where possible."""
try:
return (0, int(node_id))
except (ValueError, TypeError):
return (1, node_id)
# ---------------------------------------------------------------------------
# NodeInput dataclass
# ---------------------------------------------------------------------------
@dataclass
class NodeInput:
    """
    Descriptor for a single injectable input within a ComfyUI workflow.

    ``WorkflowInspector.inspect`` creates one instance for each scalar
    (non-link, non-``None``) input field that its key-assignment rules
    accept.

    Attributes
    ----------
    node_id : str
        The node's key in the workflow dict.
    node_class : str
        ``class_type`` of the node (e.g. ``"KSampler"``).
    node_title : str
        Value of ``_meta.title`` for the node (may be empty).
    input_name : str
        The input field name within the node's ``inputs`` dict.
    input_type : str
        Semantic type: ``"text"``, ``"seed"``, ``"image"``, ``"integer"``,
        ``"float"``, ``"string"``, ``"checkpoint"``, ``"lora"``.
    current_value : Any
        The value currently stored in the workflow template.
    label : str
        Human-readable display label of the form ``"<title> / <field>"``.
        The node's ``class_type`` stands in for an empty title.  ``<field>``
        is usually ``input_name``; seeds, checkpoints and LoRAs use
        ``seed`` / ``checkpoint`` / ``lora``.
    key : str
        Stable short key used by Discord commands and state storage.  Several
        inputs may share one key (e.g. the seed of every KSampler and
        KSamplerAdvanced maps to ``"seed"``).  An override for a shared key
        is written to every input that carries it.
    is_common : bool
        True for prompts, seeds, and image inputs (shown prominently in UI).
    """
    node_id: str
    node_class: str
    node_title: str
    input_name: str
    input_type: str
    current_value: Any
    label: str
    key: str
    is_common: bool
# ---------------------------------------------------------------------------
# WorkflowInspector
# ---------------------------------------------------------------------------
class WorkflowInspector:
    """
    Inspects and modifies ComfyUI workflow JSON.

    The inspector keeps no per-workflow state.  :meth:`inspect` re-derives the
    injectable inputs on every call, and :meth:`inject_overrides` works on a
    deep copy, so the template passed in is never mutated.

    Usage::

        inspector = WorkflowInspector()
        inputs = inspector.inspect(workflow)
        modified, applied = inspector.inject_overrides(workflow, {"prompt": "a cat"})
    """

    # Non-seed KSampler / KSamplerAdvanced fields exposed under a key shared by
    # every sampler in the workflow: input_name -> (input_type, key, is_common).
    _KSAMPLER_FIELDS = {
        "steps": ("integer", "steps", False),
        "cfg": ("float", "cfg", False),
        "sampler_name": ("string", "sampler_name", False),
        "scheduler": ("string", "scheduler", False),
        "denoise": ("float", "denoise", False),
    }

    # Classes with dedicated rules in _classify_input.  Fields of these classes
    # that match no rule are skipped rather than hitting the generic fallback.
    _HANDLED_CLASSES = frozenset({
        "CLIPTextEncode", "LoadImage", "KSampler", "KSamplerAdvanced",
        "EmptyLatentImage", "CheckpointLoaderSimple", "LoraLoader",
    })

    # Auto-generated seeds are drawn uniformly from [0, _SEED_MAX].
    _SEED_MAX = 2 ** 32 - 1

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def inspect(self, workflow: dict[str, Any]) -> list[NodeInput]:
        """
        Walk a workflow and return all injectable scalar inputs.

        Nodes are visited in numeric node-id order, so key assignment is
        deterministic (e.g. which ``LoadImage`` becomes ``"input_image"``).

        Parameters
        ----------
        workflow : dict
            ComfyUI workflow in API format (node_id → node_dict).

        Returns
        -------
        list[NodeInput]
            All injectable inputs: common ones first, then advanced ones,
            each group sorted by key.  Keys are shared only where the rules
            intend it (prompts, sampler settings, seeds).  Every image input
            gets a key of its own.
        """
        inputs: list[NodeInput] = []
        load_image_count = 0  # LoadImage inputs seen so far (first → "input_image")

        for node_id, node in sorted(workflow.items(), key=lambda kv: _numeric_sort_key(kv[0])):
            if not isinstance(node, dict):
                continue
            # Hand-edited JSON may contain nulls or non-strings here.  Normalise
            # them so the classifiers can call str methods safely.
            class_type = str(node.get("class_type") or "")
            meta = node.get("_meta")
            raw_title = meta.get("title") if isinstance(meta, dict) else None
            title = str(raw_title) if raw_title else ""

            node_inputs = node.get("inputs", {})
            if not isinstance(node_inputs, dict):
                continue

            for input_name, value in node_inputs.items():
                # Links are [source_node_id, output_slot] lists, and None has
                # no usable type.  Neither can be overridden directly.
                if value is None or isinstance(value, list):
                    continue
                ni = self._classify_input(
                    node_id=node_id,
                    class_type=class_type,
                    title=title,
                    input_name=input_name,
                    value=value,
                    load_image_count=load_image_count,
                )
                if ni is None:
                    continue
                if ni.input_type == "image":
                    load_image_count += 1
                inputs.append(ni)

        # Must run while `inputs` is still in node-id order, so renaming is
        # deterministic across calls (keys are persisted by callers).
        self._dedupe_image_keys(inputs)
        inputs.sort(key=lambda x: (0 if x.is_common else 1, x.key))
        return inputs

    def inject_overrides(
        self,
        workflow: dict[str, Any],
        overrides: dict[str, Any],
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Deep-copy a workflow and inject override values.

        Seeds that are absent from *overrides* or set to ``-1`` (including
        the string ``"-1"``) are auto-randomized.  Every other injectable key
        found via :meth:`inspect` is updated if it is present (and not
        ``None``) in *overrides*.  A key shared by several inputs is written
        to all of them.

        Parameters
        ----------
        workflow : dict
            The workflow template (not mutated).
        overrides : dict
            Mapping of ``NodeInput.key → value`` to inject.

        Returns
        -------
        tuple[dict, dict]
            ``(modified_workflow, applied_values)``.  *applied_values* maps
            each key that was actually written (including auto-generated
            seeds) to the value that was used.

        Raises
        ------
        ValueError, TypeError
            If a seed override cannot be converted with ``int()``.
        """
        wf = copy.deepcopy(workflow)
        overrides = overrides or {}
        applied: dict[str, Any] = {}

        # Group every (node_id, input_name) target under its key.  The input
        # type of the first NodeInput seen for a key decides how it is handled.
        key_targets: dict[str, list[tuple[str, str]]] = {}
        key_itype: dict[str, str] = {}
        for ni in self.inspect(wf):
            key_targets.setdefault(ni.key, []).append((ni.node_id, ni.input_name))
            key_itype.setdefault(ni.key, ni.input_type)

        for key, targets in key_targets.items():
            override_val = overrides.get(key)
            if key_itype[key] == "seed":
                value = self._resolve_seed(override_val)
            elif override_val is not None:
                value = override_val
            else:
                continue  # not overridden: keep the template's value
            for node_id, input_name in targets:
                wf[node_id]["inputs"][input_name] = value
            applied[key] = value

        return wf, applied

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @classmethod
    def _resolve_seed(cls, override_val: Any) -> int:
        """
        Return the seed to write for a seed override.

        ``None`` (absent) and ``-1`` mean "pick a random seed".  The value is
        converted with ``int()`` *before* the sentinel check, so ``"-1"`` sent
        by a text-based frontend also randomizes instead of writing -1.
        """
        if override_val is None:
            return random.randint(0, cls._SEED_MAX)
        seed = int(override_val)
        if seed == -1:
            return random.randint(0, cls._SEED_MAX)
        return seed

    @staticmethod
    def _dedupe_image_keys(inputs: list[NodeInput]) -> None:
        """
        Ensure every image input's key is used by no other input (in place).

        Secondary LoadImage keys come from the node title.  Without this pass,
        two loaders titled "Load Image" (ComfyUI's default), or a loader
        titled e.g. "Input Image" or "Seed", would share a key with another
        input and be overwritten together by :meth:`inject_overrides`.

        Rules:

        - Keys of non-image inputs always win.
        - Among image inputs, the first in *inputs* order keeps its key.
        - Later colliding image inputs get ``_{node_id}`` appended, plus a
          counter if even that key is taken.

        Keys that did not collide are left unchanged.
        """
        taken = {ni.key for ni in inputs if ni.input_type != "image"}
        for ni in inputs:
            if ni.input_type != "image":
                continue
            if ni.key in taken:
                base = f"{ni.key}_{ni.node_id}"
                candidate = base
                suffix = 2
                while candidate in taken:
                    candidate = f"{base}_{suffix}"
                    suffix += 1
                ni.key = candidate
            taken.add(ni.key)

    def _classify_input(
        self,
        *,
        node_id: str,
        class_type: str,
        title: str,
        input_name: str,
        value: Any,
        load_image_count: int,
    ) -> Optional[NodeInput]:
        """
        Return a NodeInput for one field, or None to skip it entirely.

        ``load_image_count`` is the number of LoadImage inputs already
        classified; the first one is keyed ``"input_image"``.  Later LoadImage
        keys are made unique afterwards by :meth:`inspect`.
        """
        title_display = title or class_type

        def make(input_type: str, key: str, is_common: bool,
                 label_name: Optional[str] = None) -> NodeInput:
            # Every rule below differs only in these values.
            return NodeInput(
                node_id=node_id,
                node_class=class_type,
                node_title=title,
                input_name=input_name,
                input_type=input_type,
                current_value=value,
                label=f"{title_display} / {label_name or input_name}",
                key=key,
                is_common=is_common,
            )

        # ---- CLIPTextEncode → positive/negative prompt ----
        if class_type == "CLIPTextEncode" and input_name == "text":
            lowered = title.lower()
            if "positive" in lowered:
                key = "prompt"
            elif "negative" in lowered:
                key = "negative_prompt"
            else:
                key = f"text_{node_id}"
            return make("text", key, True)

        # ---- LoadImage → input image ----
        if class_type == "LoadImage" and input_name == "image":
            if load_image_count == 0:
                key = "input_image"
            elif title:
                key = _slugify(title)  # never empty: _slugify falls back to "node"
            else:
                key = f"image_{node_id}"
            return make("image", key, True)

        # ---- KSampler / KSamplerAdvanced ----
        if class_type in ("KSampler", "KSamplerAdvanced"):
            if input_name in ("seed", "noise_seed"):
                return make("seed", "seed", True, label_name="seed")
            spec = self._KSAMPLER_FIELDS.get(input_name)
            if spec is not None:
                itype, key, is_common = spec
                return make(itype, key, is_common)

        # ---- EmptyLatentImage ----
        if class_type == "EmptyLatentImage" and input_name in ("width", "height"):
            return make("integer", input_name, False)

        # ---- CheckpointLoaderSimple ----
        if class_type == "CheckpointLoaderSimple" and input_name == "ckpt_name":
            return make("checkpoint", "checkpoint", False, label_name="checkpoint")

        # ---- LoraLoader ----
        if class_type == "LoraLoader" and input_name == "lora_name":
            return make("lora", f"lora_{node_id}", False, label_name="lora")

        # ---- Unrecognised field of a known class → skip ----
        if class_type in self._HANDLED_CLASSES:
            return None

        # ---- Generic scalar fallback ----
        # bool is tested first because it is a subclass of int; booleans
        # aren't useful override targets.
        if isinstance(value, bool):
            return None
        if isinstance(value, int):
            itype = "integer"
        elif isinstance(value, float):
            itype = "float"
        elif isinstance(value, str):
            itype = "string"
        else:
            return None  # dicts, etc. — skip
        return make(itype, f"{_slugify(class_type)}_{node_id}_{input_name}", False)