"""
workflow_state.py
=================

Workflow state management for the Discord ComfyUI bot.

This module provides a WorkflowStateManager class that stores runtime
overrides for workflow parameters (prompt, negative_prompt, input_image,
seed, steps, cfg, …) in a generic key-value dict.  Any NodeInput.key
produced by WorkflowInspector can be set as an override.

The old fixed fields (prompt, negative_prompt, input_image, seed) are
preserved as convenience wrappers for backward compatibility with existing
Discord commands and status_monitor.

State file format (current-workflow-changes.json)::

    {
      "overrides": {
        "prompt": "a beautiful landscape",
        "seed": 42,
        ...
      },
      "last_workflow_file": "my_workflow.json"
    }

The old flat-key format is migrated in memory when the file is loaded; it is
written back in the new format on the next save.

Saves are serialised completely before the disk is touched and are then
written to a sibling ``<state_file>.tmp`` that replaces the state file, so a
failed save (e.g. a non-JSON-serialisable override value) leaves the
previously saved state intact.
"""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


class WorkflowStateManager:
    """
    Manages runtime workflow overrides in memory with optional persistence.

    Override keys correspond to ``NodeInput.key`` values discovered by
    :class:`~workflow_inspector.WorkflowInspector`.  Common well-known keys
    are ``"prompt"``, ``"negative_prompt"``, ``"input_image"``, ``"seed"``.

    Parameters
    ----------
    state_file : Optional[str]
        Path to a JSON file for persisting overrides.  Loaded on init if
        the file exists; auto-saved on every change.
    """

    def __init__(self, state_file: Optional[str] = None) -> None:
        self._overrides: Dict[str, Any] = {}
        self._last_workflow_file: Optional[str] = None
        self._state_file = state_file
        if self._state_file:
            self._load_from_file()

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _load_from_file(self) -> None:
        """
        Load state from the configured JSON file if it exists.

        Both the current ``{"overrides": {...}, "last_workflow_file": ...}``
        layout and the old flat layout are accepted.  Flat files are migrated
        in memory: every key except ``last_workflow_file`` whose value is
        not ``None`` becomes an override.

        Any read/parse problem, or a top-level JSON value that is not an
        object, is logged as a warning and leaves the in-memory state
        unchanged.
        """
        if not self._state_file:
            return

        state_path = Path(self._state_file)
        if not state_path.exists():
            logger.debug("State file %s not found, using empty state", self._state_file)
            return

        try:
            with open(state_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as exc:
            logger.warning("Failed to load state from %s: %s", self._state_file, exc)
            return

        # Guard explicitly: for a JSON string, `"overrides" in data` would be a
        # substring test, and a list has no .items().
        if not isinstance(data, dict):
            logger.warning(
                "Failed to load state from %s: top-level JSON value is %s, expected an object",
                self._state_file,
                type(data).__name__,
            )
            return

        overrides = data.get("overrides")
        if isinstance(overrides, dict):
            # New format: {"overrides": {...}, "last_workflow_file": ...}
            self._overrides = overrides
        else:
            # Migrate old flat format: {"prompt": ..., "negative_prompt": ..., ...}
            self._overrides = {
                k: v
                for k, v in data.items()
                if v is not None and k not in ("last_workflow_file",)
            }
        self._last_workflow_file = data.get("last_workflow_file")
        logger.info("Loaded workflow state from %s", self._state_file)

    def save_to_file(self) -> None:
        """
        Persist current overrides and last_workflow_file to the state JSON file.

        The JSON text is built in full before any file is opened, then
        written to ``<state_file>.tmp`` and moved over the state file.  If
        serialisation or writing fails, the existing state file is not
        modified.

        Raises
        ------
        RuntimeError
            If no state file was configured.
        Exception
            Any serialisation error (e.g. ``TypeError`` for a value that is
            not JSON-serialisable) or ``OSError`` from writing/replacing is
            logged and re-raised.
        """
        if not self._state_file:
            raise RuntimeError("Cannot save state: no state file configured")

        data = {
            "overrides": self._overrides,
            "last_workflow_file": self._last_workflow_file,
        }
        try:
            # Serialise before touching the disk: json.dump() straight into an
            # already-truncated file would leave a half-written, unparsable
            # state file if an override value is not serialisable.
            text = json.dumps(data, indent=4)
        except Exception as exc:
            logger.error("Failed to save state to %s: %s", self._state_file, exc)
            raise

        state_path = Path(self._state_file)
        tmp_path = Path(str(state_path) + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(text)
            # Path.replace overwrites an existing target (os.replace semantics).
            tmp_path.replace(state_path)
        except Exception as exc:
            logger.error("Failed to save state to %s: %s", self._state_file, exc)
            try:
                tmp_path.unlink()
            except OSError:
                pass  # temp file was never created or is already gone
            raise
        logger.debug("Saved workflow state to %s", self._state_file)

    def _autosave(self) -> None:
        """Save to file silently if a state file is configured."""
        if self._state_file:
            try:
                self.save_to_file()
            except Exception:
                pass  # already logged inside save_to_file

    # ------------------------------------------------------------------
    # Generic override API
    # ------------------------------------------------------------------

    def get_overrides(self) -> Dict[str, Any]:
        """Return a shallow copy of all current overrides."""
        return self._overrides.copy()

    def set_override(self, key: str, value: Any) -> None:
        """Set a single override key and auto-save."""
        self._overrides[key] = value
        self._autosave()

    def delete_override(self, key: str) -> None:
        """Remove a single override key (no-op if absent) and auto-save."""
        self._overrides.pop(key, None)
        self._autosave()

    def clear_overrides(self) -> None:
        """Remove all override keys and auto-save."""
        self._overrides = {}
        self._autosave()

    # ------------------------------------------------------------------
    # Last-workflow-file tracking
    # ------------------------------------------------------------------

    def get_last_workflow_file(self) -> Optional[str]:
        """Return the last loaded workflow filename (for auto-load on restart)."""
        return self._last_workflow_file

    def set_last_workflow_file(self, filename: Optional[str]) -> None:
        """Record the last loaded workflow filename and auto-save."""
        self._last_workflow_file = filename
        self._autosave()

    # ------------------------------------------------------------------
    # Backward-compat: old get_changes / set_changes API
    # ------------------------------------------------------------------

    def get_changes(self) -> Dict[str, Any]:
        """
        Alias for :meth:`get_overrides` retained for backward compatibility.

        Returns a new dict that always has ``prompt``, ``negative_prompt``,
        ``input_image``, and ``seed`` keys (value is ``None`` when unset) so
        existing callers that rely on these specific keys still work.  All
        other overrides are included as well.
        """
        base: Dict[str, Any] = {
            "prompt": None,
            "negative_prompt": None,
            "input_image": None,
            "seed": None,
        }
        base.update(self._overrides)
        return base

    def set_changes(self, changes: Dict[str, Any], merge: bool = True) -> None:
        """
        Set multiple overrides at once and auto-save.

        Parameters
        ----------
        changes : dict
            Key-value pairs to apply.  ``None`` values are always skipped.
            With ``merge=True`` this means a ``None`` value does NOT delete an
            existing key; use :meth:`delete_override` for that.
        merge : bool
            If True (default), merge with existing overrides.  If False,
            replace all overrides with the non-``None`` entries of *changes*.
        """
        if merge:
            for k, v in changes.items():
                if v is not None:
                    self._overrides[k] = v
        else:
            self._overrides = {k: v for k, v in changes.items() if v is not None}
        self._autosave()

    # ------------------------------------------------------------------
    # Convenience setters / getters for well-known keys
    # ------------------------------------------------------------------

    def set_prompt(self, prompt: str) -> None:
        """Set the positive prompt override."""
        self.set_override("prompt", prompt)

    def get_prompt(self) -> Optional[str]:
        """Return the positive prompt override, or None if not set."""
        return self._overrides.get("prompt")

    def set_negative_prompt(self, negative_prompt: str) -> None:
        """Set the negative prompt override."""
        self.set_override("negative_prompt", negative_prompt)

    def get_negative_prompt(self) -> Optional[str]:
        """Return the negative prompt override, or None if not set."""
        return self._overrides.get("negative_prompt")

    def set_input_image(self, input_image: str) -> None:
        """Set the input image override (filename)."""
        self.set_override("input_image", input_image)

    def get_input_image(self) -> Optional[str]:
        """Return the input image override, or None if not set."""
        return self._overrides.get("input_image")

    def set_seed(self, seed: int) -> None:
        """Pin a specific seed for deterministic generation."""
        self.set_override("seed", seed)

    def get_seed(self) -> Optional[int]:
        """Return the pinned seed, or None if randomising each run."""
        return self._overrides.get("seed")

    def clear_seed(self) -> None:
        """Clear the pinned seed, reverting to random generation."""
        self.delete_override("seed")

    def clear(self) -> None:
        """Reset all overrides (alias for :meth:`clear_overrides`)."""
        self.clear_overrides()

    def __repr__(self) -> str:
        return (
            f"WorkflowStateManager("
            f"overrides={self._overrides!r}, "
            f"last_workflow_file={self._last_workflow_file!r})"
        )