"""
workflow_inspector.py
=====================

Dynamic workflow node inspection and injection for the Discord ComfyUI bot.

Replaces the hardcoded node-finding methods in WorkflowManager with a
general-purpose inspector that works with any ComfyUI workflow.

The inspector discovers injectable inputs at load time by walking the
workflow JSON, classifying each scalar input by class_type + input_name,
and assigning stable human-readable keys (e.g. ``"prompt"``, ``"seed"``,
``"input_image"``) that can be stored in WorkflowStateManager and used by
both Discord commands and the web UI.

Key-assignment rules
--------------------
- CLIPTextEncode / text → ``"prompt"`` / ``"negative_prompt"`` / ``"text_{node_id}"``
- LoadImage / image → ``"input_image"`` (first), ``{title_slug}`` (subsequent;
  ``"image_{node_id}"`` when untitled, ``"{title_slug}_{node_id}"`` when the
  slug is already used by an earlier image or is a reserved key)
- KSampler* / seed → ``"seed"``
- KSampler / steps|cfg|sampler_name|scheduler|denoise → same name
- EmptyLatentImage / width|height → ``"width"`` / ``"height"``
- CheckpointLoaderSimple / ckpt_name → ``"checkpoint"``
- LoraLoader / lora_name → ``"lora_{node_id}"``
- Other scalars → ``"{class_slug}_{node_id}_{input_name}"``

Several inputs may deliberately share one key (e.g. every KSampler maps its
seed to ``"seed"``); an override for that key is written to all of them.
"""

from __future__ import annotations

import copy
import random
import re
from dataclasses import dataclass
from typing import Any, Optional


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _slugify(text: str) -> str:
    """
    Convert arbitrary text to a lowercase underscore key.

    Runs of characters outside ``[a-z0-9]`` collapse to a single ``_`` and
    leading/trailing underscores are stripped. Returns ``"node"`` when
    nothing alphanumeric remains, so the result is never empty.
    """
    text = text.lower()
    text = re.sub(r"[^a-z0-9]+", "_", text)
    return text.strip("_") or "node"


def _numeric_sort_key(node_id: str) -> tuple[int, Any]:
    """
    Sort key that orders node IDs numerically where possible.

    Integer-like IDs sort first by numeric value; anything else (e.g.
    ``"10:5"``) sorts afterwards as a plain string. The leading tag keeps
    ints and strings from ever being compared with each other.
    """
    try:
        return (0, int(node_id))
    except (ValueError, TypeError):
        return (1, node_id)


def _coerce_numeric(value: Any, input_type: str) -> Any:
    """
    Convert a numeric *string* override to ``int``/``float`` for numeric inputs.

    Overrides that arrive as text (e.g. typed into a chat command) would
    otherwise be written verbatim into integer/float fields. Non-string
    values, non-numeric input types, and strings that fail to parse are
    returned unchanged.
    """
    if not isinstance(value, str):
        return value
    converter = {"integer": int, "float": float}.get(input_type)
    if converter is None:
        return value
    try:
        return converter(value.strip())
    except ValueError:
        return value


# ---------------------------------------------------------------------------
# NodeInput dataclass
# ---------------------------------------------------------------------------

@dataclass
class NodeInput:
    """
    Descriptor for a single injectable input within a ComfyUI workflow.

    Attributes
    ----------
    node_id : str
        The node's key in the workflow dict.
    node_class : str
        ``class_type`` of the node (e.g. ``"KSampler"``).
    node_title : str
        Value of ``_meta.title`` for the node (may be empty).
    input_name : str
        The input field name within the node's ``inputs`` dict.
    input_type : str
        Semantic type: ``"text"``, ``"seed"``, ``"image"``, ``"integer"``,
        ``"float"``, ``"string"``, ``"checkpoint"``, ``"lora"``.
    current_value : Any
        The value currently stored in the workflow template.
    label : str
        Human-readable display label (``"NodeTitle / input_name"``).
    key : str
        Stable short key used by Discord commands and state storage.
    is_common : bool
        True for prompts, seeds, and image inputs (shown prominently in UI).
    """

    node_id: str
    node_class: str
    node_title: str
    input_name: str
    input_type: str
    current_value: Any
    label: str
    key: str
    is_common: bool


# ---------------------------------------------------------------------------
# WorkflowInspector
# ---------------------------------------------------------------------------

class WorkflowInspector:
    """
    Inspects and modifies ComfyUI workflow JSON.

    Usage::

        inspector = WorkflowInspector()
        inputs = inspector.inspect(workflow)
        modified, applied = inspector.inject_overrides(workflow, {"prompt": "a cat"})
    """

    # KSampler / KSamplerAdvanced scalar inputs exposed as overrides:
    # input_name → (input_type, key, is_common). Built once, not per call.
    _KSAMPLER_PARAMS = {
        "steps": ("integer", "steps", False),
        "cfg": ("float", "cfg", False),
        "sampler_name": ("string", "sampler_name", False),
        "scheduler": ("string", "scheduler", False),
        "denoise": ("float", "denoise", False),
    }

    # Classes with dedicated rules above; their other fields are skipped
    # instead of falling through to the generic scalar fallback.
    _HANDLED_CLASSES = frozenset({
        "CLIPTextEncode", "LoadImage", "KSampler", "KSamplerAdvanced",
        "EmptyLatentImage", "CheckpointLoaderSimple", "LoraLoader",
    })

    # Fixed keys produced by the rules above. A secondary LoadImage whose
    # title slug equals one of these gets a node-id suffix so it cannot be
    # merged with (and mistyped as) e.g. the seed.
    _RESERVED_KEYS = frozenset({
        "input_image", "prompt", "negative_prompt", "seed", "steps", "cfg",
        "sampler_name", "scheduler", "denoise", "width", "height",
        "checkpoint",
    })

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def inspect(self, workflow: dict[str, Any]) -> list[NodeInput]:
        """
        Walk a workflow and return all injectable scalar inputs.

        Parameters
        ----------
        workflow : dict
            ComfyUI workflow in API format (node_id → node_dict).

        Returns
        -------
        list[NodeInput]
            All injectable inputs, common ones first, advanced sorted by key.
        """
        inputs: list[NodeInput] = []
        load_image_count = 0            # decides which LoadImage gets "input_image"
        image_keys: set[str] = set()    # keys already handed out to LoadImage inputs

        # Numeric node order makes key assignment (e.g. which LoadImage is
        # "first") deterministic regardless of dict insertion order.
        for node_id, node in sorted(workflow.items(), key=lambda kv: _numeric_sort_key(kv[0])):
            if not isinstance(node, dict):
                continue
            node_inputs = node.get("inputs", {})
            if not isinstance(node_inputs, dict):
                continue

            class_type: str = node.get("class_type", "")
            title = self._node_title(node)

            for input_name, value in node_inputs.items():
                # Lists are node links ([source_node_id, output_slot]);
                # None carries no type information. Neither is injectable.
                if value is None or isinstance(value, list):
                    continue

                ni = self._classify_input(
                    node_id=node_id,
                    class_type=class_type,
                    title=title,
                    input_name=input_name,
                    value=value,
                    load_image_count=load_image_count,
                    taken_keys=image_keys,
                )
                if ni is None:
                    continue
                if ni.input_type == "image":
                    load_image_count += 1
                    image_keys.add(ni.key)
                inputs.append(ni)

        # Sort: common first, then advanced (both groups sorted by key for stability)
        inputs.sort(key=lambda x: (0 if x.is_common else 1, x.key))
        return inputs

    def inject_overrides(
        self,
        workflow: dict[str, Any],
        overrides: dict[str, Any],
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Deep-copy a workflow and inject override values.

        Seeds that are absent from *overrides* or set to ``-1`` (as an int,
        float, or numeric string) are auto-randomized. All other injectable
        keys found via :meth:`inspect` are updated if present in *overrides*;
        numeric strings destined for integer/float inputs are converted.

        Parameters
        ----------
        workflow : dict
            The workflow template (not mutated).
        overrides : dict
            Mapping of ``NodeInput.key → value`` to inject.

        Returns
        -------
        tuple[dict, dict]
            ``(modified_workflow, applied_values)`` where *applied_values*
            maps each key that was actually written (including auto-generated
            seeds) to the value that was used.

        Raises
        ------
        ValueError
            If a seed override cannot be converted with ``int()``.
        """
        wf = copy.deepcopy(workflow)
        applied: dict[str, Any] = {}
        overrides = overrides or {}

        # Group every (node_id, input_name) target under its key; the first
        # input seen for a key decides how that key's override is treated.
        key_targets: dict[str, list[tuple[str, str]]] = {}
        key_itype: dict[str, str] = {}
        for ni in self.inspect(wf):
            key_targets.setdefault(ni.key, []).append((ni.node_id, ni.input_name))
            key_itype.setdefault(ni.key, ni.input_type)

        for key, targets in key_targets.items():
            itype = key_itype[key]
            override_val = overrides.get(key)

            if itype == "seed":
                # Convert before testing the sentinel so a textual "-1"
                # randomizes instead of injecting a negative seed.
                seed = None if override_val is None else int(override_val)
                if seed is None or seed == -1:
                    seed = random.randint(0, 2 ** 32 - 1)
                new_value: Any = seed
            elif override_val is not None:
                new_value = _coerce_numeric(override_val, itype)
            else:
                continue  # no override and not a seed → leave template value

            for node_id, input_name in targets:
                wf[node_id]["inputs"][input_name] = new_value
            applied[key] = new_value

        return wf, applied

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _node_title(node: dict[str, Any]) -> str:
        """Return the node's ``_meta.title`` as a string, or ``""`` if absent or malformed."""
        meta = node.get("_meta")
        if not isinstance(meta, dict):
            return ""
        raw = meta.get("title")
        return str(raw) if raw else ""

    def _classify_input(
        self,
        *,
        node_id: str,
        class_type: str,
        title: str,
        input_name: str,
        value: Any,
        load_image_count: int,
        taken_keys: Optional[set[str]] = None,
    ) -> Optional[NodeInput]:
        """
        Return a NodeInput for one field, or None to skip it entirely.

        Parameters
        ----------
        load_image_count : int
            Number of LoadImage inputs already classified; the first one
            (count 0) receives the key ``"input_image"``.
        taken_keys : set[str], optional
            Keys already assigned to earlier LoadImage inputs. A subsequent
            LoadImage whose title slug is in this set (or is a reserved key)
            is suffixed with its node id to keep image keys distinct.
        """
        title_display = title or class_type

        def make(input_type: str, key: str, is_common: bool,
                 label_name: Optional[str] = None) -> NodeInput:
            # Shared constructor: every descriptor uses the same node
            # identity fields and the "Title / field" label format.
            return NodeInput(
                node_id=node_id,
                node_class=class_type,
                node_title=title,
                input_name=input_name,
                input_type=input_type,
                current_value=value,
                label=f"{title_display} / {label_name or input_name}",
                key=key,
                is_common=is_common,
            )

        # ---- CLIPTextEncode → positive/negative prompt ----
        if class_type == "CLIPTextEncode" and input_name == "text":
            t = title.lower()
            if "positive" in t:
                key = "prompt"
            elif "negative" in t:
                key = "negative_prompt"
            else:
                key = f"text_{node_id}"
            return make("text", key, True)

        # ---- LoadImage → input image ----
        if class_type == "LoadImage" and input_name == "image":
            if load_image_count == 0:
                key = "input_image"
            else:
                key = _slugify(title) if title else f"image_{node_id}"
                taken = taken_keys if taken_keys is not None else set()
                # Identically titled images (e.g. ComfyUI's default
                # "Load Image") must not collapse into one override key.
                if key in taken or key in self._RESERVED_KEYS:
                    key = f"{key}_{node_id}"
            return make("image", key, True)

        # ---- KSampler / KSamplerAdvanced ----
        if class_type in ("KSampler", "KSamplerAdvanced"):
            if input_name in ("seed", "noise_seed"):
                return make("seed", "seed", True, label_name="seed")
            if input_name in self._KSAMPLER_PARAMS:
                itype, key, is_common = self._KSAMPLER_PARAMS[input_name]
                return make(itype, key, is_common)

        # ---- EmptyLatentImage ----
        if class_type == "EmptyLatentImage" and input_name in ("width", "height"):
            return make("integer", input_name, False)

        # ---- CheckpointLoaderSimple ----
        if class_type == "CheckpointLoaderSimple" and input_name == "ckpt_name":
            return make("checkpoint", "checkpoint", False, label_name="checkpoint")

        # ---- LoraLoader ----
        if class_type == "LoraLoader" and input_name == "lora_name":
            return make("lora", f"lora_{node_id}", False, label_name="lora")

        # ---- Skip remaining fields of already-handled classes ----
        if class_type in self._HANDLED_CLASSES:
            return None  # unrecognised field for a known class → skip

        # ---- Generic scalar fallback ----
        if isinstance(value, bool):
            return None  # booleans aren't useful override targets (bool is an int subclass, so check first)
        if isinstance(value, int):
            itype = "integer"
        elif isinstance(value, float):
            itype = "float"
        elif isinstance(value, str):
            itype = "string"
        else:
            return None  # dicts, etc. — skip

        return make(itype, f"{_slugify(class_type)}_{node_id}_{input_name}", False)