Initial commit — ComfyUI Discord bot + web UI

Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI
(React/TS/Vite/Tailwind), ComfyUI integration, generation history DB,
preset manager, workflow inspector, and all supporting modules.

Excluded from tracking: .env, invite_tokens.json, *.db (SQLite),
current-workflow-changes.json, user_settings/, presets/, logs/,
web-static/ (build output), frontend/node_modules/.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Khoa (Revenovich) Tran Gia
2026-03-02 09:55:48 +07:00
commit 1ed3c9ec4b
82 changed files with 20693 additions and 0 deletions

14
.gitattributes vendored Normal file
View File

@@ -0,0 +1,14 @@
# Normalize line endings to LF in the repository (CRLF on Windows checkout)
* text=auto eol=lf
# Binary files — no line-ending conversion
*.db binary
*.png binary
*.jpg binary
*.jpeg binary
*.gif binary
*.webp binary
*.mp4 binary
*.webm binary
*.zip binary
*.whl binary

81
.gitignore vendored Normal file
View File

@@ -0,0 +1,81 @@
# ── Python ────────────────────────────────────────────────────────────────────
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# ── Virtual environments ───────────────────────────────────────────────────────
venv/
env/
ENV/
.venv/
# ── Secrets / environment ─────────────────────────────────────────────────────
# Contains DISCORD_BOT_TOKEN and other credentials — NEVER commit
.env
# Hashed invite tokens — auth credentials; regenerate via token_store.py CLI
invite_tokens.json
# ── SQLite databases ──────────────────────────────────────────────────────────
# generation_history.db — user generation records
# input_images.db — image BLOBs (can be hundreds of MB)
*.db
# ── Runtime / generated state ─────────────────────────────────────────────────
# Active workflow overrides (prompt, seed, etc.) — machine-local runtime state
current-workflow-changes.json
# Per-user persistent settings (created at runtime under user labels)
user_settings/
# User-created presets — runtime data; not project source
presets/
# NSSM / service log files
logs/
# ── Frontend build artefacts ──────────────────────────────────────────────────
# Regenerate with: cd frontend && npm run build
web-static/
# npm dependencies — restored with: cd frontend && npm install
frontend/node_modules/
# Vite cache
frontend/.vite/
# ── IDE / editor ──────────────────────────────────────────────────────────────
.vscode/
.idea/
*.swp
*.swo
*~
# ── OS ────────────────────────────────────────────────────────────────────────
.DS_Store
Thumbs.db
# ── Syncthing ─────────────────────────────────────────────────────────────────
.stfolder/
.stignore
# ── Claude Code project files ─────────────────────────────────────────────────
# Local conversation history and per-project Claude config — not shared
.claude/

4462
AIO.json Normal file

File diff suppressed because it is too large Load Diff

194
CLAUDE.md Normal file
View File

@@ -0,0 +1,194 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
This is a Discord bot that integrates with ComfyUI to generate AI images and videos. Users interact via Discord commands, which queue generation requests that execute on a ComfyUI server.
## Architecture
### File Structure
The codebase is organized into focused modules for maintainability:
```
the-third-rev/
├── config.py # Configuration and constants
├── job_queue.py # Job queue system (SerialJobQueue)
├── workflow_manager.py # Workflow manipulation logic
├── workflow_state.py # Runtime workflow state management
├── discord_utils.py # Discord helpers and decorators
├── commands/ # Command handlers (organized by functionality)
│ ├── __init__.py # Command registration
│ ├── generation.py # generate, workflow-gen commands
│ ├── workflow.py # workflow-load command
│ ├── upload.py # upload command
│ ├── history.py # history, get-history commands
│ └── workflow_changes.py # get/set workflow changes commands
├── bot.py # Main bot entry point (~150 lines)
└── comfy_client.py # ComfyUI API client (~650 lines)
```
### Core Components
- **bot.py**: Minimal Discord bot entry point. Loads configuration, creates dependencies, and registers commands. No command logic here.
- **comfy_client.py**: Async client wrapping ComfyUI's REST and WebSocket APIs. Dependencies (WorkflowManager, WorkflowStateManager) are injected via constructor.
- **config.py**: Centralized configuration with `BotConfig.from_env()` for loading environment variables and constants.
- **job_queue.py**: `SerialJobQueue` ensuring generation requests execute sequentially, preventing ComfyUI server overload.
- **workflow_manager.py**: `WorkflowManager` class handling workflow template storage and node manipulation (finding/replacing prompts, seeds, etc).
- **workflow_state.py**: `WorkflowStateManager` class managing runtime workflow changes (prompt, negative_prompt, input_image) in memory with optional file persistence.
- **discord_utils.py**: Reusable Discord utilities including `@require_comfy_client` decorator, argument parsing, and the `UploadView` component.
- **commands/**: Command handlers organized by functionality. Each module exports a `setup_*_commands(bot, config)` function.
### Key Architectural Patterns
1. **Dependency Injection**: ComfyClient receives WorkflowManager and WorkflowStateManager as constructor parameters, eliminating tight coupling to file-based state.
2. **Job Queue System**: All generation requests are queued through `SerialJobQueue` in job_queue.py. Jobs execute serially with a worker loop that catches and logs exceptions without crashing the bot.
3. **Workflow System**: The bot uses two modes:
- **Prompt mode**: Simple prompt + negative_prompt (requires workflow template with KSampler node)
- **Workflow mode**: Full workflow JSON with dynamic modifications from WorkflowStateManager
4. **Workflow Modification Flow**:
- Load workflow template via `bot.comfy.set_workflow()` or `bot.comfy.load_workflow_from_file()`
- Runtime changes (prompt, negative_prompt, input_image) stored in WorkflowStateManager
- At generation time, WorkflowManager methods locate nodes by class_type and title metadata, then inject values
- Seeds are randomized automatically via `workflow_manager.find_and_replace_seed()`
5. **Command Registration**: Commands are registered via `commands.register_all_commands(bot, config)` which calls individual `setup_*_commands()` functions from each command module.
6. **Configuration Management**: All configuration loaded via `BotConfig.from_env()` in config.py. Constants (command prefixes, error messages, limits) centralized in config.py.
7. **History Management**: ComfyClient maintains a bounded deque of recent generations (configurable via `history_limit`) for retrieval via `ttr!get-history`.
## Environment Variables
Required in `.env`:
- `DISCORD_BOT_TOKEN`: Discord bot authentication token
- `COMFY_SERVER`: ComfyUI server address (e.g., `localhost:8188` or `example.com:8188`)
Optional:
- `WORKFLOW_FILE`: Path to JSON workflow file to load at startup
- `COMFY_HISTORY_LIMIT`: Number of generations to keep in history (default: 10)
- `COMFY_OUTPUT_PATH`: Path to ComfyUI output directory (default: `C:\Users\ktrangia\Documents\ComfyUI\output`)
## Running the Bot
```bash
python bot.py
```
The bot will:
1. Load configuration from environment variables via `BotConfig.from_env()`
2. Create WorkflowStateManager and WorkflowManager instances
3. Initialize ComfyClient with injected dependencies
4. Load workflow from `WORKFLOW_FILE` if specified
5. Register all commands via `commands.register_all_commands()`
6. Start Discord bot and job queue
7. Listen for commands with prefix `ttr!`
## Development Commands
No build/test/lint commands exist. This is a standalone Python application.
To run: `python bot.py`
## Key Implementation Details
### ComfyUI Workflow Node Injection
When generating with workflows, WorkflowManager searches for specific node patterns:
- **Prompt**: Finds `CLIPTextEncode` nodes with `_meta.title` containing "Positive Prompt"
- **Negative Prompt**: Finds `CLIPTextEncode` nodes with `_meta.title` containing "Negative Prompt"
- **Input Image**: Finds `LoadImage` nodes and replaces the `image` input
- **Seeds**: Finds any node with `inputs.seed` or `inputs.noise_seed` and randomizes
This pattern-matching approach means workflows must follow naming conventions in their node titles for dynamic updates to work.
### Discord Command Pattern
Commands use a labelled parameter syntax: `ttr!generate prompt:<text> negative_prompt:<text>`
Parsing is handled by helpers in discord_utils.py (e.g., `parse_labeled_args()`). The bot splits on keyword markers (`prompt:`, `negative_prompt:`, `type:`, etc.) rather than traditional argparse. Case is preserved for prompts.
### Job Queue Mechanics
Jobs are dataclasses with `run: Callable[[], Awaitable[None]]` and a `label` for logging. The queue returns position on submit. Jobs capture their context (ctx, prompts) via lambda closures when submitted.
### Image/Video Output Handling
The `_general_generate` method in ComfyClient returns both images and videos. Videos are identified by file extension (mp4, webm, avi) in the history response. For videos, the bot reads the file from disk at the path specified by `COMFY_OUTPUT_PATH` rather than downloading via the API.
### Command Validation
The `@require_comfy_client` decorator (from discord_utils.py) validates that `bot.comfy` exists before executing commands. This eliminates repetitive validation code in every command handler.
### State Management
WorkflowStateManager maintains runtime workflow changes in memory with optional persistence to `current-workflow-changes.json`. The file is loaded on initialization if it exists, and saved automatically when changes are made.
## Configuration System
Configuration is managed via the `BotConfig` dataclass in config.py:
```python
from config import BotConfig
# Load from environment
config = BotConfig.from_env()
# Access configuration
server = config.comfy_server
history_limit = config.comfy_history_limit
output_path = config.comfy_output_path
```
All constants (command prefixes, error messages, defaults) are defined in config.py and imported where needed.
## Adding New Commands
To add a new command:
1. Create a new module in `commands/` (e.g., `commands/my_feature.py`)
2. Define a `setup_my_feature_commands(bot, config=None)` function
3. Use `@bot.command(name="...")` decorators to define commands
4. Use `@require_comfy_client` decorator if command needs ComfyClient
5. Import and call your setup function in `commands/__init__.py`'s `register_all_commands()`
Example:
```python
# commands/my_feature.py
from discord.ext import commands
from discord_utils import require_comfy_client
def setup_my_feature_commands(bot):
@bot.command(name="my-command")
@require_comfy_client
async def my_command(ctx: commands.Context):
await ctx.reply("Hello from my command!")
```
## Dependencies
From imports:
- discord.py
- aiohttp
- websockets
- python-dotenv (optional, for .env loading)
No requirements.txt exists. Install manually: `pip install discord.py aiohttp websockets python-dotenv`
## Code Organization Principles
The refactored codebase follows these principles:
1. **Single Responsibility**: Each module has one clear purpose
2. **Dependency Injection**: Dependencies passed via constructor, not created internally
3. **Configuration Centralization**: All configuration in config.py
4. **Command Separation**: Commands grouped by functionality in separate modules
5. **No Magic Strings**: Constants defined once in config.py
6. **Type Safety**: Modern Python type hints throughout (dict[str, Any] instead of Dict)
7. **Logging**: Using logger methods instead of print() statements

780
DEVELOPMENT.md Normal file
View File

@@ -0,0 +1,780 @@
# Development Guide
This guide explains how to add new commands, features, and modules to the Discord ComfyUI bot.
## Table of Contents
- [Adding a New Command](#adding-a-new-command)
- [Adding a New Feature Module](#adding-a-new-feature-module)
- [Adding Configuration Options](#adding-configuration-options)
- [Working with Workflows](#working-with-workflows)
- [Best Practices](#best-practices)
- [Common Patterns](#common-patterns)
- [Testing Your Changes](#testing-your-changes)
---
## Adding a New Command
Commands are organized in the `commands/` directory by functionality. Here's how to add a new command:
### Step 1: Choose the Right Module
Determine which existing command module your command belongs to:
- **generation.py** - Image/video generation commands
- **workflow.py** - Workflow template management
- **upload.py** - File upload commands
- **history.py** - History viewing and retrieval
- **workflow_changes.py** - Runtime workflow parameter management
If none fit, create a new module (see [Adding a New Feature Module](#adding-a-new-feature-module)).
### Step 2: Add Your Command Function
Edit the appropriate module in `commands/` and add your command to the `setup_*_commands()` function:
```python
# commands/generation.py
def setup_generation_commands(bot, config):
# ... existing commands ...
@bot.command(name="my-new-command", aliases=["mnc", "my-cmd"])
@require_comfy_client # Use this decorator if you need bot.comfy
async def my_new_command(ctx: commands.Context, *, args: str = "") -> None:
"""
Brief description of what your command does.
Usage:
ttr!my-new-command [arguments]
Longer description with examples and details.
"""
# Parse arguments if needed
if not args:
await ctx.reply("Please provide arguments!", mention_author=False)
return
try:
# Your command logic here
result = await bot.comfy.some_method(args)
# Send response
await ctx.reply(f"Success! Result: {result}", mention_author=False)
except Exception as exc:
logger.exception("Failed to execute my-new-command")
await ctx.reply(
f"An error occurred: {type(exc).__name__}: {exc}",
mention_author=False,
)
```
### Step 3: Import Required Dependencies
At the top of your command module, import what you need:
```python
import logging
from discord.ext import commands
from discord_utils import require_comfy_client, parse_labeled_args
from config import ARG_PROMPT_KEY, ARG_TYPE_KEY
from job_queue import Job
logger = logging.getLogger(__name__)
```
### Step 4: Test Your Command
Run the bot and test your command:
```bash
python bot.py
```
Then in Discord:
```
ttr!my-new-command test arguments
```
---
## Adding a New Feature Module
If your commands don't fit existing modules, create a new one:
### Step 1: Create the Module File
Create `commands/your_feature.py`:
```python
"""
commands/your_feature.py
========================
Description of what this module handles.
"""
from __future__ import annotations
import logging
from discord.ext import commands
from discord_utils import require_comfy_client
logger = logging.getLogger(__name__)
def setup_your_feature_commands(bot, config):
"""
Register your feature commands with the bot.
Parameters
----------
bot : commands.Bot
The Discord bot instance.
config : BotConfig
The bot configuration object.
"""
@bot.command(name="feature-command")
@require_comfy_client
async def feature_command(ctx: commands.Context, *, args: str = "") -> None:
"""Command description."""
await ctx.reply("Feature command executed!", mention_author=False)
@bot.command(name="another-command", aliases=["ac"])
async def another_command(ctx: commands.Context) -> None:
"""Another command description."""
await ctx.reply("Another command!", mention_author=False)
```
### Step 2: Register in commands/__init__.py
Edit `commands/__init__.py` to import and register your module:
```python
from .generation import setup_generation_commands
from .workflow import setup_workflow_commands
from .upload import setup_upload_commands
from .history import setup_history_commands
from .workflow_changes import setup_workflow_changes_commands
from .your_feature import setup_your_feature_commands # ADD THIS
def register_all_commands(bot, config):
"""Register all bot commands."""
setup_generation_commands(bot, config)
setup_workflow_commands(bot)
setup_upload_commands(bot)
setup_history_commands(bot)
setup_workflow_changes_commands(bot)
setup_your_feature_commands(bot, config) # ADD THIS
```
### Step 3: Update Documentation
Add your module to `CLAUDE.md`:
```markdown
### File Structure
```
commands/
├── __init__.py
├── generation.py
├── workflow.py
├── upload.py
├── history.py
├── workflow_changes.py
└── your_feature.py # Your new module
```
```
---
## Adding Configuration Options
Configuration is centralized in `config.py`. Here's how to add new options:
### Step 1: Add Constants (if needed)
Edit `config.py` and add constants in the appropriate section:
```python
# ========================================
# Your Feature Constants
# ========================================
MY_FEATURE_DEFAULT_VALUE = 42
"""Default value for my feature."""
MY_FEATURE_MAX_LIMIT = 100
"""Maximum limit for my feature."""
```
### Step 2: Add to BotConfig (if environment variable)
If your config comes from environment variables, add it to `BotConfig`:
```python
@dataclass
class BotConfig:
"""Configuration container for the Discord ComfyUI bot."""
discord_bot_token: str
comfy_server: str
comfy_output_path: str
comfy_history_limit: int
workflow_file: Optional[str] = None
my_feature_enabled: bool = False # ADD THIS
my_feature_value: int = MY_FEATURE_DEFAULT_VALUE # ADD THIS
```
### Step 3: Load in from_env()
Add loading logic in `BotConfig.from_env()`:
```python
@classmethod
def from_env(cls) -> BotConfig:
"""Create a BotConfig instance by loading from environment."""
# ... existing code ...
# Load your feature config
my_feature_enabled = os.getenv("MY_FEATURE_ENABLED", "false").lower() == "true"
try:
my_feature_value = int(os.getenv("MY_FEATURE_VALUE", str(MY_FEATURE_DEFAULT_VALUE)))
except ValueError:
my_feature_value = MY_FEATURE_DEFAULT_VALUE
return cls(
# ... existing parameters ...
my_feature_enabled=my_feature_enabled,
my_feature_value=my_feature_value,
)
```
### Step 4: Use in Your Commands
Access config in your commands:
```python
def setup_your_feature_commands(bot, config):
@bot.command(name="feature")
async def feature_command(ctx: commands.Context):
if not config.my_feature_enabled:
await ctx.reply("Feature is disabled!", mention_author=False)
return
value = config.my_feature_value
await ctx.reply(f"Feature value: {value}", mention_author=False)
```
### Step 5: Document the Environment Variable
Update `CLAUDE.md` and add to `.env.example` (if you create one):
```bash
# Feature Configuration
MY_FEATURE_ENABLED=true
MY_FEATURE_VALUE=42
```
---
## Working with Workflows
The bot has separate concerns for workflows:
- **WorkflowManager** (`workflow_manager.py`) - Template storage and node manipulation
- **WorkflowStateManager** (`workflow_state.py`) - Runtime state (prompt, negative_prompt, input_image)
- **ComfyClient** (`comfy_client.py`) - Uses both managers to generate images
### Adding New Workflow Node Types
If you need to manipulate new types of nodes in workflows:
#### Step 1: Add Method to WorkflowManager
Edit `workflow_manager.py`:
```python
def find_and_replace_my_node(
self, workflow: Dict[str, Any], my_value: str
) -> Dict[str, Any]:
"""
Find and replace my custom node type.
This searches for nodes of a specific class_type and updates their inputs.
Parameters
----------
workflow : Dict[str, Any]
The workflow definition to modify.
my_value : str
The value to inject.
Returns
-------
Dict[str, Any]
The modified workflow.
"""
for node_id, node in workflow.items():
if node.get("class_type") == "MyCustomNodeType" and node.get("inputs"):
# Check metadata for specific node identification
meta = node.get("_meta", {})
if "My Custom Node" in meta.get("title", ""):
workflow[node_id]["inputs"]["my_input"] = my_value
logger.debug("Replaced my_value in node %s", node_id)
return workflow
```
#### Step 2: Add to apply_state_changes()
Update `apply_state_changes()` to include your new manipulation:
```python
def apply_state_changes(
self,
workflow: Dict[str, Any],
prompt: Optional[str] = None,
negative_prompt: Optional[str] = None,
input_image: Optional[str] = None,
my_custom_value: Optional[str] = None, # ADD THIS
randomize_seed: bool = True,
) -> Dict[str, Any]:
"""Apply multiple state changes to a workflow in one pass."""
if randomize_seed:
workflow = self.find_and_replace_seed(workflow)
if prompt is not None:
workflow = self.find_and_replace_prompt(workflow, prompt)
if negative_prompt is not None:
workflow = self.find_and_replace_negative_prompt(workflow, negative_prompt)
if input_image is not None:
workflow = self.find_and_replace_input_image(workflow, input_image)
# ADD THIS
if my_custom_value is not None:
workflow = self.find_and_replace_my_node(workflow, my_custom_value)
return workflow
```
#### Step 3: Add State to WorkflowStateManager
Edit `workflow_state.py` to track the new state:
```python
def __init__(self, state_file: Optional[str] = None):
"""Initialize the workflow state manager."""
self._state: Dict[str, Any] = {
"prompt": None,
"negative_prompt": None,
"input_image": None,
"my_custom_value": None, # ADD THIS
}
# ... rest of init ...
def set_my_custom_value(self, value: str) -> None:
"""Set the custom value."""
self._state["my_custom_value"] = value
if self._state_file:
try:
self.save_to_file()
except Exception:
pass
def get_my_custom_value(self) -> Optional[str]:
"""Get the custom value."""
return self._state.get("my_custom_value")
```
#### Step 4: Use in ComfyClient
The ComfyClient will automatically use your new state if you update `generate_image_with_workflow()`:
```python
async def generate_image_with_workflow(self) -> tuple[List[bytes], List[dict[str, Any]], str]:
# ... existing code ...
# Get current state changes
changes = self.state_manager.get_changes()
prompt = changes.get("prompt")
negative_prompt = changes.get("negative_prompt")
input_image = changes.get("input_image")
my_custom_value = changes.get("my_custom_value") # ADD THIS
# Apply changes using WorkflowManager
workflow = self.workflow_manager.apply_state_changes(
workflow,
prompt=prompt,
negative_prompt=negative_prompt,
input_image=input_image,
my_custom_value=my_custom_value, # ADD THIS
randomize_seed=True,
)
# ... rest of method ...
```
---
## Best Practices
### Command Design
1. **Use descriptive names**: `ttr!generate` is better than `ttr!gen` (but provide aliases)
2. **Validate inputs early**: Check arguments before starting long operations
3. **Provide clear feedback**: Tell users what's happening and when it's done
4. **Handle errors gracefully**: Catch exceptions and show user-friendly messages
5. **Use decorators**: `@require_comfy_client` eliminates boilerplate
### Code Organization
1. **One responsibility per module**: Don't mix unrelated commands
2. **Keep functions small**: If a function is > 50 lines, consider splitting it
3. **Use type hints**: Help future developers understand your code
4. **Document with docstrings**: Explain what, why, and how
### Discord Best Practices
1. **Use `mention_author=False`**: Prevents spam from @mentions
2. **Use `delete_after=X`**: For temporary status messages
3. **Use `ephemeral=True`**: For interaction responses (buttons/modals)
4. **Limit file attachments**: Discord has a 4-file limit (use `MAX_IMAGES_PER_RESPONSE`)
### Performance
1. **Use job queue for long operations**: Queue generation requests
2. **Use typing indicator**: `async with ctx.typing():` shows bot is working
3. **Batch operations**: Don't send 10 separate messages when 1 will do
4. **Close resources**: Always close aiohttp sessions, file handles
---
## Common Patterns
### Pattern 1: Labeled Argument Parsing
For commands with `key:value` syntax:
```python
from discord_utils import parse_labeled_args
from config import ARG_PROMPT_KEY, ARG_TYPE_KEY
@bot.command(name="my-command")
async def my_command(ctx: commands.Context, *, args: str = ""):
# Parse labeled arguments
parsed = parse_labeled_args(args, [ARG_PROMPT_KEY, ARG_TYPE_KEY])
prompt = parsed.get("prompt") # None if not provided
image_type = parsed.get("type") or "input" # Default to "input"
if not prompt:
await ctx.reply("Please provide a prompt!", mention_author=False)
return
```
### Pattern 2: Queued Job Execution
For long-running operations:
```python
from job_queue import Job
@bot.command(name="long-operation")
@require_comfy_client
async def long_operation(ctx: commands.Context, *, args: str = ""):
try:
# Define the job function
async def _run_job():
async with ctx.typing():
result = await bot.comfy.some_long_operation(args)
await ctx.reply(f"Done! Result: {result}", mention_author=False)
# Submit to queue
position = await bot.jobq.submit(
Job(
label=f"long-operation:{ctx.author.id}",
run=_run_job,
)
)
await ctx.reply(
f"Queued ✅ (position: {position})",
mention_author=False,
delete_after=2.0,
)
except Exception as exc:
logger.exception("Failed to queue long operation")
await ctx.reply(
f"An error occurred: {type(exc).__name__}: {exc}",
mention_author=False,
)
```
### Pattern 3: File Attachments
For uploading files to ComfyUI:
```python
@bot.command(name="upload-and-process")
@require_comfy_client
async def upload_and_process(ctx: commands.Context):
if not ctx.message.attachments:
await ctx.reply("Please attach a file!", mention_author=False)
return
for attachment in ctx.message.attachments:
try:
# Download attachment
data = await attachment.read()
# Upload to ComfyUI
result = await bot.comfy.upload_image(
data,
attachment.filename,
image_type="input",
)
# Process the uploaded file
filename = result.get("name")
await ctx.reply(f"Uploaded: {filename}", mention_author=False)
except Exception as exc:
logger.exception("Failed to process attachment")
await ctx.reply(
f"Failed: {attachment.filename}: {exc}",
mention_author=False,
)
```
### Pattern 4: Interactive UI (Buttons)
For adding buttons to messages:
```python
from discord.ui import View, Button
import discord
class MyView(View):
def __init__(self, data: str):
super().__init__(timeout=None)
self.data = data
@discord.ui.button(label="Click Me", style=discord.ButtonStyle.primary)
async def button_callback(
self, interaction: discord.Interaction, button: discord.ui.Button
):
# Handle button click
await interaction.response.send_message(
f"You clicked! Data: {self.data}",
ephemeral=True,
)
@bot.command(name="interactive")
async def interactive(ctx: commands.Context):
view = MyView(data="example")
await ctx.reply("Click the button:", view=view, mention_author=False)
```
### Pattern 5: Using Configuration
Access bot configuration in commands:
```python
def setup_my_commands(bot, config):
@bot.command(name="check-config")
async def check_config(ctx: commands.Context):
# Access config values
server = config.comfy_server
output_path = config.comfy_output_path
await ctx.reply(
f"Server: {server}\nOutput: {output_path}",
mention_author=False,
)
```
---
## Testing Your Changes
### Manual Testing Checklist
1. **Start the bot**: `python bot.py`
- Verify no import errors
- Check configuration loads correctly
- Confirm all commands register
2. **Test basic functionality**:
```
ttr!test
ttr!help
ttr!your-new-command
```
3. **Test error handling**:
- Run command with missing arguments
- Run command with invalid arguments
- Test when ComfyUI is unavailable (if applicable)
4. **Test edge cases**:
- Very long inputs
- Special characters
- Concurrent command execution
- Commands while queue is full
### Syntax Validation
Check for syntax errors without running:
```bash
python -m py_compile bot.py
python -m py_compile commands/your_feature.py
```
### Check Imports
Verify all imports work:
```bash
python -c "from commands.your_feature import setup_your_feature_commands; print('OK')"
```
### Code Style
Follow these conventions:
- **Indentation**: 4 spaces (no tabs)
- **Line length**: Max 100 characters (documentation can be longer)
- **Docstrings**: Use Google style or NumPy style (match existing code)
- **Imports**: Group stdlib, third-party, local (separated by blank lines)
- **Type hints**: Use modern syntax (`dict[str, Any]` not `Dict[str, Any]`)
---
## Example: Complete Feature Addition
Here's a complete example adding a "status" command:
### 1. Create commands/status.py
```python
"""
commands/status.py
==================
Bot status and diagnostics commands.
"""
from __future__ import annotations
import logging
import asyncio
from discord.ext import commands
from discord_utils import require_comfy_client
logger = logging.getLogger(__name__)
def setup_status_commands(bot, config):
"""Register status commands with the bot."""
@bot.command(name="status", aliases=["s", "stat"])
async def status_command(ctx: commands.Context) -> None:
"""
Show bot status and queue information.
Usage:
ttr!status
Displays:
- Bot connection status
- ComfyUI connection status
- Current queue size
- Configuration info
"""
# Check bot status
latency_ms = round(bot.latency * 1000)
# Check queue
if hasattr(bot, "jobq"):
queue_size = await bot.jobq.get_queue_size()
else:
queue_size = 0
# Check ComfyUI
comfy_status = "✅ Connected" if hasattr(bot, "comfy") else "❌ Not configured"
# Build status message
status_msg = [
"**Bot Status**",
f"• Latency: {latency_ms}ms",
f"• Queue size: {queue_size}",
f"• ComfyUI: {comfy_status}",
f"• Server: {config.comfy_server}",
]
await ctx.reply("\n".join(status_msg), mention_author=False)
@bot.command(name="ping")
async def ping_command(ctx: commands.Context) -> None:
"""
Check bot responsiveness.
Usage:
ttr!ping
"""
latency_ms = round(bot.latency * 1000)
await ctx.reply(f"🏓 Pong! Latency: {latency_ms}ms", mention_author=False)
```
### 2. Register in commands/__init__.py
```python
from .status import setup_status_commands
def register_all_commands(bot, config):
# ... existing registrations ...
setup_status_commands(bot, config)
```
### 3. Update CLAUDE.md
```markdown
- **commands/status.py** - Bot status and diagnostics commands
```
### 4. Test
```bash
python bot.py
```
In Discord:
```
ttr!status
ttr!ping
ttr!s (alias test)
```
---
## Getting Help
If you're stuck:
1. **Check CLAUDE.md**: Architecture and patterns documented there
2. **Read existing commands**: See how similar features are implemented
3. **Check logs**: Run bot and check console output for errors
4. **Test incrementally**: Add small pieces and test frequently
Remember: The refactored architecture makes adding features straightforward. Follow the patterns, and your code will fit right in! 🚀

196
QUICK_START.md Normal file
View File

@@ -0,0 +1,196 @@
# Quick Start Guide
Quick reference for common development tasks.
## Add a Simple Command
```python
# commands/your_module.py
def setup_your_commands(bot, config):
@bot.command(name="hello")
async def hello(ctx):
await ctx.reply("Hello!", mention_author=False)
```
Register it:
```python
# commands/__init__.py
from .your_module import setup_your_commands
def register_all_commands(bot, config):
# ... existing ...
setup_your_commands(bot, config)
```
## Add a Command That Uses ComfyUI
```python
from discord_utils import require_comfy_client
@bot.command(name="my-cmd")
@require_comfy_client # Validates bot.comfy exists
async def my_cmd(ctx):
result = await bot.comfy.some_method()
await ctx.reply(f"Result: {result}", mention_author=False)
```
## Add a Long-Running Command
```python
from job_queue import Job
@bot.command(name="generate")
@require_comfy_client
async def generate(ctx, *, args: str = ""):
async def _run():
async with ctx.typing():
result = await bot.comfy.generate_image(args)
await ctx.reply(f"Done! {result}", mention_author=False)
pos = await bot.jobq.submit(Job(label="generate", run=_run))
await ctx.reply(f"Queued ✅ (position: {pos})", mention_author=False)
```
## Add Configuration
```python
# config.py
MY_FEATURE_ENABLED = True
@dataclass
class BotConfig:
# ... existing fields ...
my_feature_enabled: bool = MY_FEATURE_ENABLED
@classmethod
def from_env(cls) -> BotConfig:
# ... existing code ...
my_feature = os.getenv("MY_FEATURE_ENABLED", "true").lower() == "true"
return cls(
# ... existing params ...
my_feature_enabled=my_feature
)
```
Use in commands:
```python
def setup_my_commands(bot, config):
@bot.command(name="feature")
async def feature(ctx):
if config.my_feature_enabled:
await ctx.reply("Enabled!", mention_author=False)
```
## Parse Command Arguments
```python
from discord_utils import parse_labeled_args
from config import ARG_PROMPT_KEY, ARG_TYPE_KEY
@bot.command(name="cmd")
async def cmd(ctx, *, args: str = ""):
# Parse "prompt:text type:value" format
parsed = parse_labeled_args(args, [ARG_PROMPT_KEY, ARG_TYPE_KEY])
prompt = parsed.get("prompt")
img_type = parsed.get("type") or "input" # Default
if not prompt:
await ctx.reply("Missing prompt!", mention_author=False)
return
```
## Handle File Uploads
```python
@bot.command(name="upload")
async def upload(ctx):
if not ctx.message.attachments:
await ctx.reply("Attach a file!", mention_author=False)
return
for attachment in ctx.message.attachments:
data = await attachment.read()
# Process data...
```
## Access Bot State
```python
@bot.command(name="info")
async def info(ctx):
# Queue size
queue_size = await bot.jobq.get_queue_size()
# Config
server = bot.config.comfy_server
# Last generation
last_id = bot.comfy.last_prompt_id
await ctx.reply(
f"Queue: {queue_size}, Server: {server}, Last: {last_id}",
mention_author=False
)
```
## Add Buttons
```python
from discord.ui import View, Button
import discord
class MyView(View):
@discord.ui.button(label="Click", style=discord.ButtonStyle.primary)
async def button_callback(self, interaction, button):
await interaction.response.send_message("Clicked!", ephemeral=True)
@bot.command(name="interactive")
async def interactive(ctx):
await ctx.reply("Press button:", view=MyView(), mention_author=False)
```
## Common Imports
```python
from __future__ import annotations
import logging
from discord.ext import commands
from discord_utils import require_comfy_client
from config import ARG_PROMPT_KEY, ARG_TYPE_KEY
from job_queue import Job
logger = logging.getLogger(__name__)
```
## Test Your Changes
```bash
# Syntax check
python -m py_compile commands/your_module.py
# Run bot
python bot.py
# Test in Discord
ttr!your-command
```
## File Structure
```
commands/
├── __init__.py # Register all commands here
├── generation.py # Generation commands
├── workflow.py # Workflow management
├── upload.py # File uploads
├── history.py # History retrieval
├── workflow_changes.py # State management
└── your_module.py # Your new module
```
## Need More Details?
See `DEVELOPMENT.md` for comprehensive guide with examples, patterns, and best practices.

345
README.md Normal file
View File

@@ -0,0 +1,345 @@
# Discord ComfyUI Bot
A Discord bot that integrates with ComfyUI to generate AI images and videos through Discord commands.
## Features
- 🎨 **Image Generation** - Generate images using simple prompts or complex workflows
- 🎬 **Video Generation** - Support for video output workflows
- 📝 **Workflow Management** - Load, modify, and execute ComfyUI workflows
- 📤 **Image Upload** - Upload reference images directly through Discord
- 📊 **Generation History** - Track and retrieve past generations
- ⚙️ **Runtime Workflow Modification** - Change prompts, negative prompts, and input images on the fly
- 🔄 **Job Queue System** - Sequential execution prevents server overload
## Quick Start
### Prerequisites
- Python 3.9+
- Discord Bot Token ([create one here](https://discord.com/developers/applications))
- ComfyUI Server running and accessible
- Required packages: `discord.py`, `aiohttp`, `websockets`, `python-dotenv`
### Installation
1. **Clone or download this repository**
2. **Install dependencies**:
```bash
pip install discord.py aiohttp websockets python-dotenv
```
3. **Create `.env` file** with your credentials:
```bash
DISCORD_BOT_TOKEN=your_discord_bot_token_here
COMFY_SERVER=localhost:8188
```
4. **Run the bot**:
```bash
python bot.py
```
## Configuration
Create a `.env` file in the project root:
```bash
# Required
DISCORD_BOT_TOKEN=your_discord_bot_token
COMFY_SERVER=localhost:8188
# Optional
WORKFLOW_FILE=wan2.2-fast.json
COMFY_HISTORY_LIMIT=10
COMFY_OUTPUT_PATH=C:\Users\YourName\Documents\ComfyUI\output
```
### Configuration Options
| Variable | Required | Default | Description |
|----------|----------|---------|-------------|
| `DISCORD_BOT_TOKEN` | ✅ Yes | - | Discord bot authentication token |
| `COMFY_SERVER` | ✅ Yes | - | ComfyUI server address (host:port) |
| `WORKFLOW_FILE` | ❌ No | - | Path to workflow JSON to load at startup |
| `COMFY_HISTORY_LIMIT` | ❌ No | `10` | Number of generations to keep in history |
| `COMFY_OUTPUT_PATH` | ❌ No | `C:\Users\...\ComfyUI\output` | Path to ComfyUI output directory |
## Usage
All commands use the `ttr!` prefix.
### Basic Commands
```bash
# Test if bot is working
ttr!test
# Generate an image with a prompt
ttr!generate prompt:a beautiful sunset over mountains
# Generate with negative prompt
ttr!generate prompt:a cat negative_prompt:blurry, low quality
# Execute loaded workflow
ttr!workflow-gen
# Queue multiple workflow runs
ttr!workflow-gen queue:5
```
### Workflow Management
```bash
# Load a workflow from file
ttr!workflow-load path/to/workflow.json
# Or attach a JSON file to the message:
ttr!workflow-load
[Attach: my_workflow.json]
# View current workflow changes
ttr!get-current-workflow-changes type:all
# Set workflow parameters
ttr!set-current-workflow-changes type:prompt A new prompt
ttr!set-current-workflow-changes type:negative_prompt blurry
ttr!set-current-workflow-changes type:input_image input/image.png
```
### Image Upload
```bash
# Upload images to ComfyUI
ttr!upload
[Attach: image1.png, image2.png]
# Upload to specific folder
ttr!upload type:temp
[Attach: reference.png]
```
### History
```bash
# View recent generations
ttr!history
# Retrieve images from a past generation
ttr!get-history <prompt_id>
ttr!get-history 1 # By index
```
### Command Aliases
Many commands have shorter aliases:
- `ttr!generate` → `ttr!gen`
- `ttr!workflow-gen` → `ttr!wfg`
- `ttr!workflow-load` → `ttr!wfl`
- `ttr!get-history` → `ttr!gh`
- `ttr!get-current-workflow-changes` → `ttr!gcwc`
- `ttr!set-current-workflow-changes` → `ttr!scwc`
## Architecture
The bot is organized into focused, maintainable modules:
```
the-third-rev/
├── config.py # Configuration and constants
├── job_queue.py # Job queue system
├── workflow_manager.py # Workflow manipulation
├── workflow_state.py # Runtime state management
├── discord_utils.py # Discord utilities
├── bot.py # Main entry point (~150 lines)
├── comfy_client.py # ComfyUI API client (~650 lines)
└── commands/ # Command handlers
├── generation.py # Image/video generation
├── workflow.py # Workflow management
├── upload.py # File uploads
├── history.py # History retrieval
└── workflow_changes.py # State management
```
### Key Design Principles
- **Dependency Injection** - Dependencies passed via constructor
- **Single Responsibility** - Each module has one clear purpose
- **Configuration Centralization** - All config in `config.py`
- **Command Separation** - Commands grouped by functionality
- **Type Safety** - Modern Python type hints throughout
## Development
### Adding a New Command
See `QUICK_START.md` for quick examples or `DEVELOPMENT.md` for comprehensive guide.
Basic example:
```python
# commands/your_module.py
def setup_your_commands(bot, config):
@bot.command(name="hello")
async def hello(ctx):
await ctx.reply("Hello!", mention_author=False)
```
Register in `commands/__init__.py`:
```python
from .your_module import setup_your_commands
def register_all_commands(bot, config):
# ... existing ...
setup_your_commands(bot, config)
```
### Documentation
- **README.md** (this file) - Project overview and setup
- **QUICK_START.md** - Quick reference for common tasks
- **DEVELOPMENT.md** - Comprehensive development guide
- **CLAUDE.md** - Architecture documentation for Claude Code
## Workflow System
The bot supports two generation modes:
### 1. Prompt Mode (Simple)
Uses a workflow template with a KSampler node:
```bash
ttr!generate prompt:a cat negative_prompt:blurry
```
The bot automatically finds and replaces:
- Positive prompt in CLIPTextEncode node (title: "Positive Prompt")
- Negative prompt in CLIPTextEncode node (title: "Negative Prompt")
- Seed values (randomized each run)
### 2. Workflow Mode (Advanced)
Execute full workflow with runtime modifications:
```bash
# Set workflow parameters
ttr!set-current-workflow-changes type:prompt A beautiful landscape
ttr!set-current-workflow-changes type:input_image input/reference.png
# Execute workflow
ttr!workflow-gen
```
The bot:
1. Loads the workflow template
2. Applies runtime changes from WorkflowStateManager
3. Randomizes seeds
4. Executes on ComfyUI server
5. Returns images/videos
### Node Naming Conventions
For workflows to work with dynamic updates, nodes must follow naming conventions:
- **Positive Prompt**: CLIPTextEncode node with title containing "Positive Prompt"
- **Negative Prompt**: CLIPTextEncode node with title containing "Negative Prompt"
- **Input Image**: LoadImage node (any title)
- **Seeds**: Any node with `inputs.seed` or `inputs.noise_seed`
## Troubleshooting
### Bot won't start
**Issue**: `AttributeError: module 'queue' has no attribute 'SimpleQueue'`
**Solution**: This was fixed by renaming `queue.py` to `job_queue.py`. Make sure you're using the latest version.
### ComfyUI connection issues
**Issue**: `ComfyUI client is not configured`
**Solution**:
1. Check `.env` file has `DISCORD_BOT_TOKEN` and `COMFY_SERVER`
2. Verify ComfyUI server is running
3. Test connection: `curl http://localhost:8188`
### Commands not responding
**Issue**: Bot online but commands don't work
**Solution**:
1. Check bot has Message Content Intent enabled in Discord Developer Portal
2. Verify bot has permissions in Discord server
3. Check console logs for errors
### Video files not found
**Issue**: `Failed to read video file`
**Solution**:
1. Set `COMFY_OUTPUT_PATH` in `.env` to your ComfyUI output directory
2. Check path uses correct format for your OS
## Advanced Usage
### Batch Generation
Queue multiple workflow runs:
```bash
ttr!workflow-gen queue:10
```
Each run uses randomized seeds for variation.
### Custom Workflows
1. Design workflow in ComfyUI
2. Export as API format (Save → API Format)
3. Load in bot:
```bash
ttr!workflow-load path/to/workflow.json
```
4. Modify at runtime:
```bash
ttr!set-current-workflow-changes type:prompt My prompt
ttr!workflow-gen
```
### State Persistence
Workflow changes are automatically saved to `current-workflow-changes.json` and persist across bot restarts.
## Contributing
We welcome contributions! Please:
1. Read `DEVELOPMENT.md` for coding guidelines
2. Follow existing code style and patterns
3. Test your changes thoroughly
4. Update documentation as needed
## License
[Your License Here]
## Support
For issues or questions:
- Check the troubleshooting section above
- Review `DEVELOPMENT.md` for implementation details
- Check ComfyUI documentation for workflow issues
- Open an issue on GitHub
## Credits
Built with:
- [discord.py](https://github.com/Rapptz/discord.py) - Discord API wrapper
- [ComfyUI](https://github.com/comfyanonymous/ComfyUI) - Stable Diffusion GUI
- [aiohttp](https://github.com/aio-libs/aiohttp) - Async HTTP client
- [websockets](https://github.com/python-websockets/websockets) - WebSocket implementation

156
backfill_image_data.py Normal file
View File

@@ -0,0 +1,156 @@
"""
backfill_image_data.py
======================
One-shot script to download image bytes from Discord and store them in
input_images.db for rows that currently have image_data = NULL.
These rows were created before the BLOB-storage migration, so their bytes
were never persisted. The script re-fetches each bot-reply message from
Discord and writes the raw attachment bytes back into the DB.
Rows with bot_reply_id = 0 (web uploads that pre-date the migration) have
no Discord source and are skipped — re-upload them via the web UI to
backfill.
Usage
-----
python backfill_image_data.py
Requires:
DISCORD_BOT_TOKEN in .env (same token the bot uses)
"""
from __future__ import annotations
import asyncio
import logging
import sqlite3
import discord
try:
from dotenv import load_dotenv
load_dotenv()
except Exception:
pass
from config import BotConfig
from input_image_db import DB_PATH
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)
def _load_null_rows() -> list[dict]:
"""Return all rows that are missing image_data."""
conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row
rows = conn.execute(
"SELECT id, bot_reply_id, channel_id, filename"
" FROM input_images WHERE image_data IS NULL"
).fetchall()
conn.close()
return [dict(r) for r in rows]
def _save_image_data(row_id: int, data: bytes) -> None:
conn = sqlite3.connect(str(DB_PATH))
conn.execute("UPDATE input_images SET image_data = ? WHERE id = ?", (data, row_id))
conn.commit()
conn.close()
async def _backfill(client: discord.Client) -> None:
rows = _load_null_rows()
discord_rows = [r for r in rows if r["bot_reply_id"] != 0]
web_rows = [r for r in rows if r["bot_reply_id"] == 0]
logger.info(
"Rows missing image_data: %d total (%d from Discord, %d web-uploads skipped)",
len(rows), len(discord_rows), len(web_rows),
)
if web_rows:
logger.info(
"Skipped row IDs (no Discord source — re-upload via web UI): %s",
[r["id"] for r in web_rows],
)
if not discord_rows:
logger.info("Nothing to fetch. Exiting.")
return
ok = 0
failed = 0
for row in discord_rows:
row_id = row["id"]
ch_id = row["channel_id"]
msg_id = row["bot_reply_id"]
filename = row["filename"]
try:
channel = client.get_channel(ch_id) or await client.fetch_channel(ch_id)
message = await channel.fetch_message(msg_id)
attachment = next(
(a for a in message.attachments if a.filename == filename), None
)
if attachment is None:
logger.warning(
"Row %d: attachment '%s' not found on message %d — skipping",
row_id, filename, msg_id,
)
failed += 1
continue
data = await attachment.read()
_save_image_data(row_id, data)
logger.info("Row %d: saved '%s' (%d bytes)", row_id, filename, len(data))
ok += 1
except discord.NotFound:
logger.warning("Row %d: message %d not found (deleted?) — skipping", row_id, msg_id)
failed += 1
except discord.Forbidden:
logger.warning("Row %d: no access to channel %d — skipping", row_id, ch_id)
failed += 1
except Exception as exc:
logger.error("Row %d: unexpected error — %s", row_id, exc)
failed += 1
logger.info(
"Done. %d saved, %d failed/skipped, %d web-upload rows not touched.",
ok, failed, len(web_rows),
)
async def _main(token: str) -> None:
intents = discord.Intents.none() # no gateway events needed beyond connect
client = discord.Client(intents=intents)
@client.event
async def on_ready():
logger.info("Logged in as %s", client.user)
try:
await _backfill(client)
finally:
await client.close()
await client.start(token)
def main() -> None:
try:
config = BotConfig.from_env()
except RuntimeError as exc:
logger.error("Config error: %s", exc)
return
asyncio.run(_main(config.discord_bot_token))
if __name__ == "__main__":
main()

218
bot.py Normal file
View File

@@ -0,0 +1,218 @@
"""
bot.py
======
Discord bot entry point. In WEB_ENABLED mode, also starts a FastAPI/Uvicorn
web server in the same asyncio event loop via asyncio.gather.
Jobs are submitted directly to ComfyUI — no internal SerialJobQueue.
"""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
import discord
from discord.ext import commands
from comfy_client import ComfyClient
from config import BotConfig, COMMAND_PREFIX
import generation_db
from input_image_db import init_db, get_all_images
from status_monitor import StatusMonitor
from workflow_manager import WorkflowManager
from workflow_state import WorkflowStateManager
from commands import register_all_commands, CustomHelpCommand
from commands.input_images import PersistentSetInputView
from commands.server import autostart_comfy
try:
from dotenv import load_dotenv
load_dotenv()
except Exception:
pass
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
_PROJECT_ROOT = Path(__file__).resolve().parent
# ---------------------------------------------------------------------------
# Bot setup
# ---------------------------------------------------------------------------
def get_prefix(bot, message):
"""Dynamic command prefix getter."""
msg = message.content.lower()
if msg.startswith(COMMAND_PREFIX):
return COMMAND_PREFIX
return COMMAND_PREFIX
intents = discord.Intents.default()
intents.message_content = True
intents.guilds = True
bot = commands.Bot(
command_prefix=get_prefix,
intents=intents,
help_command=CustomHelpCommand(),
)
# ---------------------------------------------------------------------------
# Event handlers
# ---------------------------------------------------------------------------
@bot.event
async def on_ready() -> None:
logger.info("Logged in as %s (ID: %s)", bot.user, bot.user.id)
if not hasattr(bot, "start_time"):
bot.start_time = datetime.now(timezone.utc)
cfg = getattr(bot, "config", None)
if cfg:
for row in get_all_images():
view = PersistentSetInputView(bot, cfg, row["id"])
bot.add_view(view, message_id=row["bot_reply_id"])
asyncio.create_task(autostart_comfy(cfg))
if not hasattr(bot, "status_monitor"):
log_ch = getattr(getattr(bot, "config", None), "log_channel_id", None)
if log_ch:
bot.status_monitor = StatusMonitor(bot, log_ch)
if hasattr(bot, "status_monitor"):
await bot.status_monitor.start()
@bot.event
async def on_disconnect() -> None:
logger.info("Discord connection closed")
@bot.event
async def on_resumed() -> None:
logger.info("Discord session resumed")
# ---------------------------------------------------------------------------
# Startup helpers
# ---------------------------------------------------------------------------
async def create_comfy(config: BotConfig) -> ComfyClient:
state_manager = WorkflowStateManager(state_file="current-workflow-changes.json")
workflow_manager = WorkflowManager()
return ComfyClient(
server_address=config.comfy_server,
workflow_manager=workflow_manager,
state_manager=state_manager,
logger=logger,
history_limit=config.comfy_history_limit,
output_path=config.comfy_output_path,
)
def _try_autoload_last_workflow(client: ComfyClient) -> None:
"""Re-load the last used workflow from the workflows/ folder on startup."""
last_wf = client.state_manager.get_last_workflow_file()
if not last_wf:
return
wf_path = _PROJECT_ROOT / "workflows" / last_wf
if not wf_path.exists():
logger.warning("Last workflow file not found: %s", wf_path)
return
try:
with open(wf_path, "r", encoding="utf-8") as f:
workflow = json.load(f)
# Restore template without clearing overrides on restart
client.workflow_manager.set_workflow_template(workflow)
logger.info("Auto-loaded last workflow: %s", last_wf)
except Exception as exc:
logger.error("Failed to auto-load workflow %s: %s", last_wf, exc)
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
async def main() -> None:
try:
config: BotConfig = BotConfig.from_env()
logger.info("Configuration loaded")
except RuntimeError as exc:
logger.error("Configuration error: %s", exc)
return
bot.comfy = await create_comfy(config)
bot.config = config
# Auto-load last workflow (restores template without clearing overrides)
_try_autoload_last_workflow(bot.comfy)
# Fallback: WORKFLOW_FILE env var
if not bot.comfy.get_workflow_template() and config.workflow_file:
try:
bot.comfy.load_workflow_from_file(config.workflow_file)
logger.info("Loaded workflow from %s", config.workflow_file)
except Exception as exc:
logger.error("Failed to load workflow %s: %s", config.workflow_file, exc)
from user_state_registry import UserStateRegistry
bot.user_registry = UserStateRegistry(
settings_dir=_PROJECT_ROOT / "user_settings",
default_workflow=bot.comfy.get_workflow_template(),
)
init_db()
generation_db.init_db(_PROJECT_ROOT / "generation_history.db")
register_all_commands(bot, config)
logger.info("All commands registered")
coroutines = [bot.start(config.discord_bot_token)]
if config.web_enabled:
try:
import uvicorn
from web.deps import set_bot
from web.app import create_app
set_bot(bot)
fastapi_app = create_app()
uvi_config = uvicorn.Config(
fastapi_app,
host=config.web_host,
port=config.web_port,
log_level="info",
loop="none", # use existing event loop
)
uvi_server = uvicorn.Server(uvi_config)
coroutines.append(uvi_server.serve())
logger.info(
"Web UI enabled at http://%s:%d", config.web_host, config.web_port
)
except ImportError:
logger.warning(
"uvicorn or fastapi not installed — web UI disabled. "
"pip install fastapi uvicorn[standard]"
)
try:
await asyncio.gather(*coroutines)
finally:
if hasattr(bot, "status_monitor"):
await bot.status_monitor.stop()
if hasattr(bot, "comfy"):
await bot.comfy.close()
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Received interrupt, shutting down…")

604
comfy_client.py Normal file
View File

@@ -0,0 +1,604 @@
"""
comfy_client.py
================
Asynchronous client for the ComfyUI API.
Wraps ComfyUI's REST and WebSocket endpoints. Workflow template injection
is now handled by :class:`~workflow_inspector.WorkflowInspector`, so this
class only needs to:
1. Accept a workflow template (delegated to WorkflowManager).
2. Accept runtime overrides (delegated to WorkflowStateManager).
3. Build the final workflow via inspector.inject_overrides().
4. Queue it to ComfyUI, wait for completion via WebSocket, fetch outputs.
A ``{prompt_id: callback}`` map is maintained for future WebSocket
broadcasting (web UI phase). Discord commands still use the synchronous
await-and-return model.
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from typing import Any, Callable, Dict, List, Optional
import aiohttp
import websockets
from workflow_inspector import WorkflowInspector
class ComfyClient:
"""
Asynchronous ComfyUI client.
Parameters
----------
server_address : str
``hostname:port`` of the ComfyUI server.
workflow_manager : WorkflowManager
Template storage (injected).
state_manager : WorkflowStateManager
Runtime overrides (injected).
logger : Optional[logging.Logger]
Logger for debug/info messages.
history_limit : int
Max recent generations to keep in the in-memory deque.
"""
def __init__(
self,
server_address: str,
workflow_manager,
state_manager,
logger: Optional[logging.Logger] = None,
*,
history_limit: int = 10,
output_path: Optional[str] = None,
) -> None:
self.server_address = server_address.strip().rstrip("/")
self.client_id = str(uuid.uuid4())
self._session: Optional[aiohttp.ClientSession] = None
self.protocol = "http"
self.ws_protocol = "ws"
self.workflow_manager = workflow_manager
self.state_manager = state_manager
self.inspector = WorkflowInspector()
self.output_path = output_path
# prompt_id → asyncio.Future for web-UI broadcast (Phase 4)
self._pending_callbacks: Dict[str, Callable] = {}
from collections import deque
self._history = deque(maxlen=history_limit)
self.last_prompt_id: Optional[str] = None
self.last_seed: Optional[int] = None
self.total_generated: int = 0
self.logger = logger if logger else logging.getLogger(__name__)
# ------------------------------------------------------------------
# Session
# ------------------------------------------------------------------
@property
def session(self) -> aiohttp.ClientSession:
"""Lazily create and return an aiohttp session."""
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession()
return self._session
# ------------------------------------------------------------------
# Low-level REST helpers
# ------------------------------------------------------------------
async def _queue_prompt(
self,
prompt: dict[str, Any],
prompt_id: str,
ws_client_id: str | None = None,
) -> dict[str, Any]:
"""Submit a workflow to the ComfyUI queue."""
payload = {
"prompt": prompt,
"client_id": ws_client_id if ws_client_id is not None else self.client_id,
"prompt_id": prompt_id,
}
url = f"{self.protocol}://{self.server_address}/prompt"
async with self.session.post(url, json=payload,
headers={"Content-Type": "application/json"}) as resp:
resp.raise_for_status()
return await resp.json()
async def _wait_for_execution(
self,
prompt_id: str,
on_progress: Optional[Callable[[str, str], None]] = None,
ws_client_id: str | None = None,
) -> None:
"""
Wait for a queued prompt to finish executing via WebSocket.
Parameters
----------
prompt_id : str
The prompt to wait for.
on_progress : Optional[Callable[[str, str], None]]
Called with ``(node_id, prompt_id)`` for each ``node_executing``
event. Pass ``None`` for Discord commands (no web broadcast
needed).
"""
client_id = ws_client_id if ws_client_id is not None else self.client_id
ws_url = (
f"{self.ws_protocol}://{self.server_address}/ws"
f"?clientId={client_id}"
)
async with websockets.connect(ws_url) as ws:
try:
while True:
out = await ws.recv()
if not isinstance(out, str):
continue
message = json.loads(out)
mtype = message.get("type")
if mtype == "executing":
data = message["data"]
node = data.get("node")
if node:
self.logger.debug("Executing node: %s", node)
if on_progress and data.get("prompt_id") == prompt_id:
try:
on_progress(node, prompt_id)
except Exception:
pass
if data["node"] is None and data.get("prompt_id") == prompt_id:
self.logger.info("Execution complete for prompt %s", prompt_id)
break
elif mtype == "execution_success":
if message.get("data", {}).get("prompt_id") == prompt_id:
self.logger.info("execution_success for prompt %s", prompt_id)
break
elif mtype == "execution_error":
if message.get("data", {}).get("prompt_id") == prompt_id:
error = message.get("data", {}).get("exception_message", "unknown error")
raise RuntimeError(f"ComfyUI execution error: {error}")
except Exception as exc:
self.logger.error("Error during execution wait: %s", exc)
raise
async def _get_history(self, prompt_id: str) -> dict[str, Any]:
"""Retrieve execution history for a given prompt id."""
url = f"{self.protocol}://{self.server_address}/history/{prompt_id}"
async with self.session.get(url) as resp:
resp.raise_for_status()
return await resp.json()
async def _download_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
"""Download an image from ComfyUI and return raw bytes."""
url = f"{self.protocol}://{self.server_address}/view"
params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
async with self.session.get(url, params=params) as resp:
resp.raise_for_status()
return await resp.read()
# ------------------------------------------------------------------
# Core generation pipeline
# ------------------------------------------------------------------
async def _general_generate(
self,
workflow: dict[str, Any],
prompt_id: str,
on_progress: Optional[Callable[[str, str], None]] = None,
) -> tuple[List[bytes], List[dict[str, Any]]]:
"""
Queue a workflow, wait for it to execute, then collect outputs.
Returns
-------
tuple[List[bytes], List[dict]]
``(images, videos)`` — images as raw bytes, videos as info dicts.
"""
ws_client_id = str(uuid.uuid4())
await self._queue_prompt(workflow, prompt_id, ws_client_id)
try:
await self._wait_for_execution(prompt_id, on_progress=on_progress, ws_client_id=ws_client_id)
except Exception:
self.logger.error("Execution failed for prompt %s", prompt_id)
return [], []
history = await self._get_history(prompt_id)
if not history:
self.logger.warning("No history for prompt %s", prompt_id)
return [], []
images: List[bytes] = []
videos: List[dict[str, Any]] = []
for node_output in history.get(prompt_id, {}).get("outputs", {}).values():
for image_info in node_output.get("images", []):
name = image_info["filename"]
if name.rsplit(".", 1)[-1].lower() in {"mp4", "webm", "avi"}:
videos.append({
"video_name": name,
"video_subfolder": image_info.get("subfolder", ""),
"video_type": image_info.get("type", "output"),
})
else:
data = await self._download_image(
name, image_info["subfolder"], image_info["type"]
)
images.append(data)
return images, videos
# ------------------------------------------------------------------
# DB persistence helper
# ------------------------------------------------------------------
def _record_to_db(
self,
prompt_id: str,
source: str,
user_label: Optional[str],
overrides: Dict[str, Any],
seed: Optional[int],
images: List[bytes],
videos: List[Dict[str, Any]],
) -> None:
"""Persist generation metadata and file blobs to SQLite. Never raises."""
try:
import generation_db
from pathlib import Path as _Path
gen_id = generation_db.record_generation(
prompt_id, source, user_label, overrides, seed
)
for i, img_data in enumerate(images):
generation_db.record_file(gen_id, f"image_{i:04d}.png", img_data)
if videos and self.output_path:
for vid in videos:
vname = vid.get("video_name", "")
vsub = vid.get("video_subfolder", "")
vpath = (
_Path(self.output_path) / vsub / vname
if vsub
else _Path(self.output_path) / vname
)
try:
generation_db.record_file(gen_id, vname, vpath.read_bytes())
except OSError as exc:
self.logger.warning(
"Could not read video for DB storage: %s: %s", vpath, exc
)
except Exception as exc:
self.logger.warning("Failed to record generation to DB: %s", exc)
# ------------------------------------------------------------------
# Public generation API
# ------------------------------------------------------------------
async def generate_image(
self,
prompt: str,
negative_prompt: Optional[str] = None,
on_progress: Optional[Callable[[str, str], None]] = None,
*,
source: str = "discord",
user_label: Optional[str] = None,
) -> tuple[List[bytes], str]:
"""
Generate images using the current workflow template with a text prompt.
Injects *prompt* (and optionally *negative_prompt*) via the inspector,
plus any currently pinned seed from the state manager. All other
overrides in the state manager are **not** applied here — use
:meth:`generate_image_with_workflow` for the full override set.
Parameters
----------
prompt : str
Positive prompt text.
negative_prompt : Optional[str]
Negative prompt text (optional).
on_progress : Optional[Callable]
Called with ``(node_id, prompt_id)`` for each executing node.
Returns
-------
tuple[List[bytes], str]
``(images, prompt_id)``
"""
template = self.workflow_manager.get_workflow_template()
if not template:
self.logger.warning("No workflow template set; cannot generate.")
return [], ""
overrides: Dict[str, Any] = {"prompt": prompt}
if negative_prompt is not None:
overrides["negative_prompt"] = negative_prompt
# Respect pinned seed from state manager
seed_pin = self.state_manager.get_seed()
if seed_pin is not None:
overrides["seed"] = seed_pin
workflow, applied = self.inspector.inject_overrides(template, overrides)
seed_used = applied.get("seed")
self.last_seed = seed_used
prompt_id = str(uuid.uuid4())
images, _videos = await self._general_generate(workflow, prompt_id, on_progress)
self.last_prompt_id = prompt_id
self.total_generated += 1
self._history.append({
"prompt_id": prompt_id,
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": seed_used,
})
self._record_to_db(
prompt_id, source, user_label,
{"prompt": prompt, "negative_prompt": negative_prompt},
seed_used, images, [],
)
return images, prompt_id
async def generate_image_with_workflow(
self,
on_progress: Optional[Callable[[str, str], None]] = None,
*,
source: str = "discord",
user_label: Optional[str] = None,
) -> tuple[List[bytes], List[dict[str, Any]], str]:
"""
Generate images/videos from the current workflow applying ALL
overrides stored in the state manager.
Returns
-------
tuple[List[bytes], List[dict], str]
``(images, videos, prompt_id)``
"""
template = self.workflow_manager.get_workflow_template()
prompt_id = str(uuid.uuid4())
if not template:
self.logger.error("No workflow template set")
return [], [], prompt_id
overrides = self.state_manager.get_overrides()
workflow, applied = self.inspector.inject_overrides(template, overrides)
seed_used = applied.get("seed")
self.last_seed = seed_used
images, videos = await self._general_generate(workflow, prompt_id, on_progress)
self.last_prompt_id = prompt_id
self.total_generated += 1
prompt_str = overrides.get("prompt") or ""
neg_str = overrides.get("negative_prompt") or ""
self._history.append({
"prompt_id": prompt_id,
"prompt": (prompt_str[:10] + "") if len(prompt_str) > 10 else prompt_str or None,
"negative_prompt": (neg_str[:10] + "") if len(neg_str) > 10 else neg_str or None,
"seed": seed_used,
})
self._record_to_db(prompt_id, source, user_label, overrides, seed_used, images, videos)
return images, videos, prompt_id
# ------------------------------------------------------------------
# Workflow template management
# ------------------------------------------------------------------
def set_workflow(self, workflow: dict[str, Any]) -> None:
"""Set the workflow template and clear all state overrides."""
self.workflow_manager.set_workflow_template(workflow)
self.state_manager.clear_overrides()
def load_workflow_from_file(self, path: str) -> None:
"""
Load a workflow template from a JSON file.
Also clears state overrides and records the filename in the state
manager for auto-load on restart.
"""
import json as _json
with open(path, "r", encoding="utf-8") as f:
workflow = _json.load(f)
self.workflow_manager.set_workflow_template(workflow)
self.state_manager.clear_overrides()
from pathlib import Path
self.state_manager.set_last_workflow_file(Path(path).name)
def get_workflow_template(self) -> Optional[dict[str, Any]]:
"""Return the current workflow template (or None)."""
return self.workflow_manager.get_workflow_template()
# ------------------------------------------------------------------
# State management convenience wrappers
# ------------------------------------------------------------------
def get_workflow_current_changes(self) -> dict[str, Any]:
"""Return all current overrides (backward-compat)."""
return self.state_manager.get_changes()
def set_workflow_current_changes(self, changes: dict[str, Any]) -> None:
"""Merge override changes (backward-compat)."""
self.state_manager.set_changes(changes, merge=True)
def set_workflow_current_prompt(self, prompt: str) -> None:
self.state_manager.set_prompt(prompt)
def set_workflow_current_negative_prompt(self, negative_prompt: str) -> None:
self.state_manager.set_negative_prompt(negative_prompt)
def set_workflow_current_input_image(self, input_image: str) -> None:
self.state_manager.set_input_image(input_image)
def get_current_workflow_prompt(self) -> Optional[str]:
return self.state_manager.get_prompt()
def get_current_workflow_negative_prompt(self) -> Optional[str]:
return self.state_manager.get_negative_prompt()
def get_current_workflow_input_image(self) -> Optional[str]:
return self.state_manager.get_input_image()
# ------------------------------------------------------------------
# Image upload
# ------------------------------------------------------------------
async def upload_image(
self,
data: bytes,
filename: str,
*,
image_type: str = "input",
overwrite: bool = False,
) -> dict[str, Any]:
"""Upload an image to ComfyUI via the /upload/image endpoint."""
url = f"{self.protocol}://{self.server_address}/upload/image"
form = aiohttp.FormData()
form.add_field("image", data, filename=filename,
content_type="application/octet-stream")
form.add_field("type", image_type)
form.add_field("overwrite", str(overwrite).lower())
async with self.session.post(url, data=form) as resp:
resp.raise_for_status()
try:
return await resp.json()
except aiohttp.ContentTypeError:
return {"status": await resp.text()}
# ------------------------------------------------------------------
# History
# ------------------------------------------------------------------
def get_history(self) -> List[dict]:
"""Return a list of recently generated prompt records (from DB)."""
try:
from generation_db import get_history as db_get_history
return db_get_history(limit=self._history.maxlen or 50)
except Exception:
return list(self._history)
async def fetch_history_images(self, prompt_id: str) -> List[bytes]:
"""Re-download images for a previously generated prompt."""
history = await self._get_history(prompt_id)
images: List[bytes] = []
for node_output in history.get(prompt_id, {}).get("outputs", {}).values():
for image_info in node_output.get("images", []):
data = await self._download_image(
image_info["filename"],
image_info["subfolder"],
image_info["type"],
)
images.append(data)
return images
# ------------------------------------------------------------------
# Server info / queue
# ------------------------------------------------------------------
async def get_system_stats(self) -> Optional[dict[str, Any]]:
"""Fetch ComfyUI system stats (/system_stats)."""
try:
url = f"{self.protocol}://{self.server_address}/system_stats"
async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
resp.raise_for_status()
return await resp.json()
except Exception as exc:
self.logger.warning("Failed to fetch system stats: %s", exc)
return None
async def get_comfy_queue(self) -> Optional[dict[str, Any]]:
"""Fetch the ComfyUI queue (/queue)."""
try:
url = f"{self.protocol}://{self.server_address}/queue"
async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
resp.raise_for_status()
return await resp.json()
except Exception as exc:
self.logger.warning("Failed to fetch comfy queue: %s", exc)
return None
async def get_queue_depth(self) -> int:
"""Return the total number of pending + running jobs in ComfyUI."""
q = await self.get_comfy_queue()
if q:
return len(q.get("queue_running", [])) + len(q.get("queue_pending", []))
return 0
async def clear_queue(self) -> bool:
"""Clear all pending jobs from the ComfyUI queue."""
try:
url = f"{self.protocol}://{self.server_address}/queue"
async with self.session.post(
url,
json={"clear": True},
headers={"Content-Type": "application/json"},
timeout=aiohttp.ClientTimeout(total=5),
) as resp:
return resp.status in (200, 204)
except Exception as exc:
self.logger.warning("Failed to clear comfy queue: %s", exc)
return False
async def check_connection(self) -> bool:
"""Return True if the ComfyUI server is reachable."""
try:
url = f"{self.protocol}://{self.server_address}/system_stats"
async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
return resp.status == 200
except Exception:
return False
async def get_models(self, model_type: str = "checkpoints") -> List[str]:
"""
Fetch available model names from ComfyUI.
Parameters
----------
model_type : str
One of ``"checkpoints"``, ``"loras"``, etc.
"""
try:
url = f"{self.protocol}://{self.server_address}/object_info"
async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
resp.raise_for_status()
info = await resp.json()
if model_type == "checkpoints":
node = info.get("CheckpointLoaderSimple", {})
return node.get("input", {}).get("required", {}).get("ckpt_name", [None])[0] or []
elif model_type == "loras":
node = info.get("LoraLoader", {})
return node.get("input", {}).get("required", {}).get("lora_name", [None])[0] or []
return []
except Exception as exc:
self.logger.warning("Failed to fetch models (%s): %s", model_type, exc)
return []
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
async def close(self) -> None:
"""Close the underlying aiohttp session."""
if self._session and not self._session.closed:
await self._session.close()
async def __aenter__(self) -> "ComfyClient":
return self
async def __aexit__(self, exc_type, exc, tb) -> None:
await self.close()

64
commands/__init__.py Normal file
View File

@@ -0,0 +1,64 @@
"""
commands package
================
Discord bot commands for the ComfyUI bot.
This package contains all command handlers organized by functionality:
- generation: Image/video generation commands (generate, workflow-gen, rerun, cancel)
- workflow: Workflow management commands
- upload: Image upload commands
- history: History viewing and retrieval commands
- workflow_changes: Runtime workflow parameter management (prompt, seed, etc.)
- utility: Quality-of-life commands (ping, status, comfy-stats, comfy-queue, uptime)
- presets: Named workflow preset management
"""
from __future__ import annotations
from config import BotConfig
from .generation import setup_generation_commands
from .input_images import setup_input_image_commands
from .server import setup_server_commands
from .workflow import setup_workflow_commands
from .history import setup_history_commands
from .workflow_changes import setup_workflow_changes_commands
from .utility import setup_utility_commands
from .presets import setup_preset_commands
from .help_command import CustomHelpCommand
def register_all_commands(bot, config: BotConfig):
"""
Register all bot commands.
This function should be called once during bot initialization to set up
all command handlers.
Parameters
----------
bot : commands.Bot
The Discord bot instance.
config : BotConfig
The bot configuration object containing environment settings.
"""
setup_generation_commands(bot, config)
setup_input_image_commands(bot, config)
setup_server_commands(bot, config)
setup_workflow_commands(bot)
setup_history_commands(bot)
setup_workflow_changes_commands(bot)
setup_utility_commands(bot)
setup_preset_commands(bot)
__all__ = [
"register_all_commands",
"setup_generation_commands",
"setup_input_image_commands",
"setup_workflow_commands",
"setup_history_commands",
"setup_workflow_changes_commands",
"setup_utility_commands",
"setup_preset_commands",
"CustomHelpCommand",
]

389
commands/generation.py Normal file
View File

@@ -0,0 +1,389 @@
"""
commands/generation.py
======================
Image and video generation commands for the Discord ComfyUI bot.
Jobs are submitted directly to ComfyUI (no internal SerialJobQueue).
ComfyUI's own queue handles ordering. Each Discord command waits for its
prompt_id to complete via WebSocket and then replies with the result.
"""
from __future__ import annotations
import asyncio
import logging
try:
import aiohttp # type: ignore
except Exception: # pragma: no cover
aiohttp = None # type: ignore
from io import BytesIO
from pathlib import Path
from typing import Optional
import discord
from discord.ext import commands
from config import ARG_PROMPT_KEY, ARG_NEG_PROMPT_KEY, ARG_QUEUE_KEY, MAX_IMAGES_PER_RESPONSE
from discord_utils import require_comfy_client, convert_image_bytes_to_discord_files
from media_uploader import flush_pending
logger = logging.getLogger(__name__)
async def _safe_reply(
ctx: commands.Context,
*,
content: str | None = None,
files: list[discord.File] | None = None,
mention_author: bool = True,
delete_after: float | None = None,
tries: int = 4,
base_delay: float = 1.0,
):
"""Reply to Discord with retries for transient network/Discord errors."""
delay = base_delay
last_exc: Exception | None = None
for attempt in range(1, tries + 1):
try:
return await ctx.reply(
content=content,
files=files or [],
mention_author=mention_author,
delete_after=delete_after,
)
except Exception as exc: # noqa: BLE001
last_exc = exc
transient = False
if isinstance(exc, asyncio.TimeoutError):
transient = True
elif isinstance(exc, OSError) and getattr(exc, "winerror", None) in {
64, 121, 1231, 10053, 10054,
}:
transient = True
if aiohttp is not None:
try:
if isinstance(exc, (aiohttp.ClientOSError, aiohttp.ClientConnectionError)):
transient = True
except Exception:
pass
if isinstance(exc, discord.HTTPException):
status = getattr(exc, "status", None)
if status is None or status >= 500 or status == 429:
transient = True
if (not transient) or attempt == tries:
raise
logger.warning(
"Transient error sending Discord message (attempt %d/%d): %s: %s",
attempt, tries, type(exc).__name__, exc,
)
await asyncio.sleep(delay)
delay *= 2
raise last_exc # type: ignore[misc]
def _seed_line(bot) -> str:
"""Return a formatted seed line if a seed was tracked, else empty string."""
seed = getattr(bot.comfy, "last_seed", None)
return f"\nSeed: `{seed}`" if seed is not None else ""
async def _run_generate(ctx: commands.Context, bot, prompt_text: str, negative_text: Optional[str]):
"""Execute a prompt-based generation and reply with results."""
images, prompt_id = await bot.comfy.generate_image(
prompt_text, negative_text,
source="discord", user_label=ctx.author.display_name,
)
if not images:
await ctx.reply(
"No images were generated. Please try again with a different prompt.",
mention_author=False,
)
return
files = convert_image_bytes_to_discord_files(
images, max_files=MAX_IMAGES_PER_RESPONSE, prefix="generated"
)
response_text = f"Generated {len(images)} image(s). Prompt ID: `{prompt_id}`{_seed_line(bot)}"
await _safe_reply(ctx, content=response_text, files=files, mention_author=True)
asyncio.create_task(flush_pending(
Path(bot.config.comfy_output_path),
bot.config.media_upload_user,
bot.config.media_upload_pass,
))
async def _run_workflow(ctx: commands.Context, bot, config):
"""Execute a workflow-based generation and reply with results."""
logger.info("Executing workflow generation")
await ctx.reply("Executing workflow…", mention_author=False, delete_after=5.0)
images, videos, prompt_id = await bot.comfy.generate_image_with_workflow(
source="discord", user_label=ctx.author.display_name,
)
if not images and not videos:
await ctx.reply(
"No images or videos were generated. Check the workflow and ComfyUI logs.",
mention_author=False,
)
return
seed_info = _seed_line(bot)
if videos:
output_path = config.comfy_output_path
video_file = None
for video_info in videos:
video_name = video_info.get("video_name")
video_subfolder = video_info.get("video_subfolder", "")
if video_name:
video_path = (
Path(output_path) / video_subfolder / video_name
if video_subfolder
else Path(output_path) / video_name
)
try:
video_file = discord.File(
BytesIO(video_path.read_bytes()), filename=video_name
)
break
except Exception as exc:
logger.exception("Failed to read video %s: %s", video_path, exc)
if video_file:
response_text = (
f"Generated {len(images)} image(s) and a video. "
f"Prompt ID: `{prompt_id}`{seed_info}"
)
await _safe_reply(ctx, content=response_text, files=[video_file], mention_author=True)
else:
await ctx.reply(
f"Generated output but failed to read video file. "
f"Prompt ID: `{prompt_id}`{seed_info}",
mention_author=True,
)
else:
files = convert_image_bytes_to_discord_files(
images, max_files=MAX_IMAGES_PER_RESPONSE, prefix="generated"
)
response_text = (
f"Generated {len(images)} image(s) using workflow. "
f"Prompt ID: `{prompt_id}`{seed_info}"
)
await _safe_reply(ctx, content=response_text, files=files, mention_author=True)
asyncio.create_task(flush_pending(
Path(config.comfy_output_path),
config.media_upload_user,
config.media_upload_pass,
))
def setup_generation_commands(bot, config):
"""Register generation commands with the bot."""
@bot.command(name="test", extras={"category": "Generation"})
async def test_command(ctx: commands.Context) -> None:
"""A simple test command to verify the bot is working."""
await ctx.reply(
"The bot is working! Use `ttr!generate` to create images.",
mention_author=False,
)
@bot.command(name="generate", aliases=["gen"], extras={"category": "Generation"})
@require_comfy_client
async def generate(ctx: commands.Context, *, args: str = "") -> None:
"""
Generate images using ComfyUI.
Usage::
ttr!generate prompt:<your prompt> negative_prompt:<your negatives>
The ``prompt:`` keyword is required. ``negative_prompt:`` is optional.
"""
prompt_text: Optional[str] = None
negative_text: Optional[str] = None
if args:
if ARG_PROMPT_KEY in args:
parts = args.split(ARG_PROMPT_KEY, 1)[1]
if ARG_NEG_PROMPT_KEY in parts:
p, n = parts.split(ARG_NEG_PROMPT_KEY, 1)
prompt_text = p.strip()
negative_text = n.strip() or None
else:
prompt_text = parts.strip()
else:
prompt_text = args.strip()
if not prompt_text:
await ctx.reply(
f"Please specify a prompt: `{ARG_PROMPT_KEY}<your prompt>`.",
mention_author=False,
)
return
bot.last_gen = {"mode": "prompt", "prompt": prompt_text, "negative": negative_text}
try:
# Show queue position from ComfyUI before waiting
depth = await bot.comfy.get_queue_depth()
pos = depth + 1
ack = await ctx.reply(
f"Queued ✅ (ComfyUI position: ~{pos})",
mention_author=False,
delete_after=30.0,
)
await _run_generate(ctx, bot, prompt_text, negative_text)
except Exception as exc:
logger.exception("Error generating image")
await ctx.reply(
f"An error occurred: {type(exc).__name__}: {exc}",
mention_author=False,
)
@bot.command(
name="workflow-gen",
aliases=["workflow-generate", "wfg"],
extras={"category": "Generation"},
)
@require_comfy_client
async def generate_workflow_command(ctx: commands.Context, *, args: str = "") -> None:
"""
Generate using the currently loaded workflow template.
Usage::
ttr!workflow-gen
ttr!workflow-gen queue:<number>
"""
bot.last_gen = {"mode": "workflow", "prompt": None, "negative": None}
# Handle batch queue parameter
if ARG_QUEUE_KEY in args:
number_part = args.split(ARG_QUEUE_KEY, 1)[1].strip()
if number_part.isdigit():
queue_times = int(number_part)
if queue_times > 1:
await ctx.reply(
f"Queuing {queue_times} workflow runs…",
mention_author=False,
)
for i in range(queue_times):
try:
depth = await bot.comfy.get_queue_depth()
pos = depth + 1
await ctx.reply(
f"Queued run {i+1}/{queue_times} ✅ (ComfyUI position: ~{pos})",
mention_author=False,
delete_after=30.0,
)
await _run_workflow(ctx, bot, config)
except Exception as exc:
logger.exception("Error on workflow run %d", i + 1)
await ctx.reply(
f"Error on run {i+1}: {type(exc).__name__}: {exc}",
mention_author=False,
)
return
else:
await ctx.reply(
"Please provide a number greater than 1 for queueing multiple runs.",
mention_author=False,
delete_after=30.0,
)
return
else:
await ctx.reply(
f"Invalid queue parameter. Use `{ARG_QUEUE_KEY}<number>`.",
mention_author=False,
delete_after=30.0,
)
return
try:
depth = await bot.comfy.get_queue_depth()
pos = depth + 1
await ctx.reply(
f"Queued ✅ (ComfyUI position: ~{pos})",
mention_author=False,
delete_after=30.0,
)
await _run_workflow(ctx, bot, config)
except Exception as exc:
logger.exception("Error generating with workflow")
await ctx.reply(
f"An error occurred: {type(exc).__name__}: {exc}",
mention_author=False,
)
@bot.command(name="rerun", aliases=["rr"], extras={"category": "Generation"})
@require_comfy_client
async def rerun_command(ctx: commands.Context) -> None:
"""
Re-run the last generation with the same parameters.
Re-submits the most recent ``ttr!generate`` or ``ttr!workflow-gen``
with the same mode and prompt. Current state overrides (seed,
input_image, etc.) are applied at execution time.
"""
last = getattr(bot, "last_gen", None)
if last is None:
await ctx.reply(
"No previous generation to rerun.",
mention_author=False,
)
return
try:
depth = await bot.comfy.get_queue_depth()
pos = depth + 1
await ctx.reply(
f"Rerun queued ✅ (ComfyUI position: ~{pos})",
mention_author=False,
delete_after=30.0,
)
if last["mode"] == "prompt":
await _run_generate(ctx, bot, last["prompt"], last["negative"])
else:
await _run_workflow(ctx, bot, config)
except Exception as exc:
logger.exception("Error queueing rerun")
await ctx.reply(
f"An error occurred: {type(exc).__name__}: {exc}",
mention_author=False,
)
@bot.command(name="cancel", extras={"category": "Generation"})
@require_comfy_client
async def cancel_command(ctx: commands.Context) -> None:
"""
Clear all pending jobs from the ComfyUI queue.
Usage::
ttr!cancel
"""
try:
ok = await bot.comfy.clear_queue()
if ok:
await ctx.reply("ComfyUI queue cleared.", mention_author=False)
else:
await ctx.reply(
"Failed to clear the ComfyUI queue (server may have returned an error).",
mention_author=False,
)
except Exception as exc:
await ctx.reply(f"Error: {exc}", mention_author=False)

134
commands/help_command.py Normal file
View File

@@ -0,0 +1,134 @@
"""
commands/help_command.py
========================
Custom help command for the Discord ComfyUI bot.
Replaces discord.py's default help with a categorised listing that
automatically includes every registered command.
How it works
------------
Each ``@bot.command()`` decorator should carry an ``extras`` dict with a
``"category"`` key:
@bot.command(name="my-command", extras={"category": "Generation"})
async def my_command(ctx):
\"""One-line brief shown in the listing.
Longer description shown in ttr!help my-command.
\"""
The first line of the docstring becomes the brief shown in the main
listing. The full docstring is shown when the user asks for per-command
detail. Commands without a category appear under **Other**.
Usage
-----
ttr!help — list all commands grouped by category
ttr!help <command> — detailed help for a specific command
"""
from __future__ import annotations
from collections import defaultdict
from typing import List, Mapping, Optional
from discord.ext import commands
# Order in which categories appear in the full help listing.
# Any category not listed here appears at the end, sorted alphabetically.
CATEGORY_ORDER = ["Generation", "Workflow", "Upload", "History", "Presets", "Utility"]
def _category_sort_key(name: str) -> tuple:
"""Return a sort key that respects CATEGORY_ORDER, then alphabetical."""
try:
return (CATEGORY_ORDER.index(name), name)
except ValueError:
return (len(CATEGORY_ORDER), name)
class CustomHelpCommand(commands.HelpCommand):
"""
Categorised help command.
Groups commands by the ``"category"`` value in their ``extras`` dict.
Commands that omit this appear under **Other**.
Adding a new command to the help output requires no changes here —
just set ``extras={"category": "..."}`` on the decorator and write a
descriptive docstring.
"""
# ------------------------------------------------------------------
# Main listing — ttr!help
# ------------------------------------------------------------------
async def send_bot_help(
self,
mapping: Mapping[Optional[commands.Cog], List[commands.Command]],
) -> None:
"""Send the full command listing grouped by category."""
# Collect all visible commands across every cog / None bucket
all_commands: List[commands.Command] = []
for cmds in mapping.values():
filtered = await self.filter_commands(cmds)
all_commands.extend(filtered)
# Group by category
categories: dict[str, list[commands.Command]] = defaultdict(list)
for cmd in all_commands:
cat = cmd.extras.get("category", "Other")
categories[cat].append(cmd)
prefix = self.context.prefix
lines: list[str] = [f"**Commands** — prefix: `{prefix}`\n"]
for cat in sorted(categories.keys(), key=_category_sort_key):
cmds = sorted(categories[cat], key=lambda c: c.name)
lines.append(f"**{cat}**")
for cmd in cmds:
aliases = (
f" ({', '.join(cmd.aliases)})" if cmd.aliases else ""
)
brief = cmd.short_doc or "No description."
lines.append(f" `{cmd.name}`{aliases}{brief}")
lines.append("")
lines.append(
f"Use `{prefix}help <command>` for details on a specific command."
)
await self.get_destination().send("\n".join(lines))
# ------------------------------------------------------------------
# Per-command detail — ttr!help <command>
# ------------------------------------------------------------------
async def send_command_help(self, command: commands.Command) -> None:
"""Send detailed help for a single command."""
prefix = self.context.prefix
header = f"`{prefix}{command.name}`"
if command.aliases:
alias_list = ", ".join(f"`{a}`" for a in command.aliases)
header += f" (aliases: {alias_list})"
category = command.extras.get("category", "Other")
lines: list[str] = [header, f"Category: **{category}**", ""]
if command.help:
lines.append(command.help.strip())
else:
lines.append("No description available.")
await self.get_destination().send("\n".join(lines))
# ------------------------------------------------------------------
# Error — unknown command name
# ------------------------------------------------------------------
async def send_error_message(self, error: str) -> None:
"""Forward the error text to the channel."""
await self.get_destination().send(error)

169
commands/history.py Normal file
View File

@@ -0,0 +1,169 @@
"""
commands/history.py
===================
History management commands for the Discord ComfyUI bot.
This module contains commands for viewing and retrieving past generation
results from the bot's history.
"""
from __future__ import annotations
import logging
from io import BytesIO
from typing import Optional
import discord
from discord.ext import commands
from config import MAX_IMAGES_PER_RESPONSE
from discord_utils import require_comfy_client, truncate_text, convert_image_bytes_to_discord_files
logger = logging.getLogger(__name__)
def setup_history_commands(bot):
"""
Register history management commands with the bot.
Parameters
----------
bot : commands.Bot
The Discord bot instance.
"""
@bot.command(name="history", extras={"category": "History"})
@require_comfy_client
async def history_command(ctx: commands.Context) -> None:
"""
Show a list of recently generated prompts.
The bot keeps a rolling history of the last few generations. Each
entry lists the prompt id along with the positive and negative
prompt texts. You can retrieve the images from a previous
generation with the ``ttr!gethistory <prompt_id>`` command.
"""
hist = bot.comfy.get_history()
if not hist:
await ctx.reply(
"No history available yet. Generate something first!",
mention_author=False,
)
return
# Build a human readable list
lines = ["Here are the most recent generations (oldest first):"]
for entry in hist:
pid = entry.get("prompt_id", "unknown")
prompt = entry.get("prompt") or ""
neg = entry.get("negative_prompt") or ""
# Truncate long prompts for readability
lines.append(
f"• ID: {pid} | prompt: '{truncate_text(prompt, 60)}' | negative: '{truncate_text(neg, 60)}'"
)
await ctx.reply("\n".join(lines), mention_author=False)
@bot.command(name="get-history", aliases=["gethistory", "gh"], extras={"category": "History"})
@require_comfy_client
async def get_history_command(ctx: commands.Context, *, arg: str = "") -> None:
"""
Retrieve images from a previous generation, or search history by keyword.
Usage:
ttr!gethistory <prompt_id_or_index>
ttr!gethistory search:<keyword>
Provide either the prompt id returned in the generation response
(shown in `ttr!history`) or the 1based index into the history
list. The bot will fetch the images associated with that
generation and resend them. If no images are found, you will be
notified.
Use ``search:<keyword>`` to filter history by prompt text, checkpoint
name, seed value, or any other override field.
"""
if not arg:
await ctx.reply(
"Please provide a prompt id, history index, or `search:<keyword>`. See `ttr!history` for a list.",
mention_author=False,
)
return
# Handle search:<keyword>
lower_arg = arg.lower()
if lower_arg.startswith("search:"):
keyword = arg[len("search:"):].strip()
if not keyword:
await ctx.reply("Please provide a keyword after `search:`.", mention_author=False)
return
from generation_db import search_history_for_user, get_history as db_get_history
# Use get_history for Discord since Discord bot doesn't have per-user context like the web UI
hist = db_get_history(limit=50)
matches = [
e for e in hist
if keyword.lower() in str(e.get("overrides", {})).lower()
]
if not matches:
await ctx.reply(f"No history entries matching `{keyword}`.", mention_author=False)
return
lines = [f"**History matching `{keyword}`** ({len(matches)} result(s))"]
for entry in matches[:10]:
pid = entry.get("prompt_id", "unknown")
overrides = entry.get("overrides") or {}
prompt = str(overrides.get("prompt") or "")
lines.append(
f"• `{pid[:12]}…` | {truncate_text(prompt, 60) if prompt else '(no prompt)'}"
)
if len(matches) > 10:
lines.append(f"_(showing first 10 of {len(matches)})_")
await ctx.reply("\n".join(lines), mention_author=False)
return
# Determine whether arg refers to an index or an id
target_id: Optional[str] = None
hist = bot.comfy.get_history()
# If arg is a digit, interpret as 1based index
if arg.isdigit():
idx = int(arg) - 1
if idx < 0 or idx >= len(hist):
await ctx.reply(
f"Index out of range. There are {len(hist)} entries in history.",
mention_author=False,
)
return
target_id = hist[idx]["prompt_id"]
else:
# Otherwise treat as an explicit prompt id
target_id = arg.strip()
try:
images = await bot.comfy.fetch_history_images(target_id)
if not images:
await ctx.reply(
f"No images found for prompt id `{target_id}`.",
mention_author=False,
)
return
files = []
for idx, img_bytes in enumerate(images):
if idx >= MAX_IMAGES_PER_RESPONSE:
break
file_obj = BytesIO(img_bytes)
file_obj.seek(0)
files.append(discord.File(file_obj, filename=f"history_{target_id}_{idx+1}.png"))
await ctx.reply(
content=f"Here are the images for prompt id `{target_id}`:",
files=files,
mention_author=False,
)
except Exception as exc:
logger.exception("Failed to fetch history for %s", target_id)
await ctx.reply(
f"An error occurred: {type(exc).__name__}: {exc}",
mention_author=False,
)

178
commands/input_images.py Normal file
View File

@@ -0,0 +1,178 @@
"""
commands/input_images.py
========================
Channel-backed input image management.
Images uploaded to the designated `comfy-input` channel get a persistent
"✅ Set as input" button posted by the bot — one reply per attachment so
every image in a multi-image message is independently selectable.
Persistent views survive bot restarts: on_ready re-registers every view
stored in the SQLite database.
"""
from __future__ import annotations
import io
import logging
from pathlib import Path
import discord
from discord.ext import commands
from image_utils import compress_to_discord_limit
from input_image_db import (
activate_image_for_slot,
get_all_images,
upsert_image,
)
logger = logging.getLogger(__name__)
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp"}
class PersistentSetInputView(discord.ui.View):
"""
A persistent view that survives bot restarts.
One instance is created per DB row (i.e. per attachment).
The button's custom_id encodes the row id so the callback can look
up the exact filename to download.
"""
def __init__(self, bot, config, row_id: int):
super().__init__(timeout=None)
self._bot = bot
self._config = config
self._row_id = row_id
btn = discord.ui.Button(
label="✅ Set as input",
style=discord.ButtonStyle.success,
custom_id=f"set_input:{row_id}",
)
btn.callback = self._set_callback
self.add_item(btn)
async def _set_callback(self, interaction: discord.Interaction) -> None:
await interaction.response.defer(ephemeral=True)
try:
filename = activate_image_for_slot(
self._row_id, "input_image", self._config.comfy_input_path
)
self._bot.comfy.state_manager.set_override("input_image", filename)
await interaction.followup.send(
f"✅ Input image set to `{filename}`", ephemeral=True
)
except Exception as exc:
logger.exception("set_input button failed for row %s", self._row_id)
await interaction.followup.send(f"❌ Error: {exc}", ephemeral=True)
async def _register_attachment(bot, config, message: discord.Message, attachment: discord.Attachment) -> None:
"""Post a reply with the image preview, a Set-as-input button, and record it in the DB."""
logger.info("[_register_attachment] Start")
original_data = await attachment.read()
original_filename = attachment.filename
logger.info("[_register_attachment] Reading attachment")
# Compress only for the Discord re-send (8 MiB bot limit)
send_data, send_filename = compress_to_discord_limit(original_data, original_filename)
file = discord.File(io.BytesIO(send_data), filename=send_filename)
reply = await message.channel.send(f"`{original_filename}`", file=file)
# Store original quality bytes in DB
row_id = upsert_image(message.id, reply.id, message.channel.id, original_filename, image_data=original_data)
view = PersistentSetInputView(bot, config, row_id)
bot.add_view(view, message_id=reply.id)
logger.info("[_register_attachment] Done")
await reply.edit(view=view)
def setup_input_image_commands(bot, config=None):
"""Register input image commands and the on_message listener."""
@bot.listen("on_message")
async def _on_input_channel_message(message: discord.Message) -> None:
"""Watch the comfy-input channel and attach a Set-as-input button to every image upload."""
if config is None:
logger.warning("[_on_input_channel_message] Config is none")
return
if message.channel.id != config.comfy_input_channel_id:
return
if message.author.bot:
return
image_attachments = [
a for a in message.attachments
if Path(a.filename).suffix.lower() in IMAGE_EXTENSIONS
]
if not image_attachments:
logger.info("[_on_input_channel_message] No image attachments")
return
for attachment in image_attachments:
await _register_attachment(bot, config, message, attachment)
try:
await message.delete()
except discord.Forbidden:
logger.warning("Missing manage_messages permission to delete message %s", message.id)
except Exception as exc:
logger.warning("Could not delete message %s: %s", message.id, exc)
@bot.command(
name="sync-inputs",
aliases=["si"],
extras={"category": "Files"},
help="Scan the comfy-input channel and add 'Set as input' buttons to any untracked images.",
)
async def sync_inputs_command(ctx: commands.Context) -> None:
"""Backfill Set-as-input buttons for images uploaded while the bot was offline."""
if config is None:
await ctx.reply("Bot config is not available.", mention_author=False)
return
channel = bot.get_channel(config.comfy_input_channel_id)
if channel is None:
try:
channel = await bot.fetch_channel(config.comfy_input_channel_id)
except Exception as exc:
await ctx.reply(f"❌ Could not access input channel: {exc}", mention_author=False)
return
# Track existing records as (message_id, filename) pairs
existing = {(row["original_message_id"], row["filename"]) for row in get_all_images()}
new_count = 0
async for message in channel.history(limit=None):
if message.author.bot:
continue
had_new = False
for attachment in message.attachments:
if Path(attachment.filename).suffix.lower() not in IMAGE_EXTENSIONS:
continue
if (message.id, attachment.filename) in existing:
continue
await _register_attachment(bot, config, message, attachment)
existing.add((message.id, attachment.filename))
new_count += 1
had_new = True
if had_new:
try:
await message.delete()
except Exception as exc:
logger.warning("sync-inputs: could not delete message %s: %s", message.id, exc)
already = len(get_all_images()) - new_count
await ctx.reply(
f"Synced {new_count} new image(s). {already} already known.",
mention_author=False,
)

370
commands/presets.py Normal file
View File

@@ -0,0 +1,370 @@
"""
commands/presets.py
===================
Named workflow preset commands for the Discord ComfyUI bot.
A preset is a saved snapshot of the current workflow template and runtime
state (prompt, negative_prompt, input_image, seed). Presets make it easy
to switch between different setups (e.g. "portrait", "landscape", "anime")
with a single command.
All sub-commands are accessed through the single ``ttr!preset`` command:
ttr!preset save <name> [description:<text>] — capture current workflow + state
ttr!preset load <name> — restore workflow + state
ttr!preset list — list all saved presets
ttr!preset view <name> — show preset details
ttr!preset delete <name> — permanently remove a preset
ttr!preset save-last <name> [description:<text>] — save last generation as preset
"""
from __future__ import annotations
import logging
from discord.ext import commands
from discord_utils import require_comfy_client
from preset_manager import PresetManager
logger = logging.getLogger(__name__)
def _parse_name_and_description(args: str) -> tuple[str, str | None]:
"""
Split ``<name> [description:<text>]`` into (name, description).
The name is the first whitespace-delimited token. Everything after
``description:`` (case-insensitive) in the remaining text is the
description. Returns (name, None) if no description keyword is found.
"""
parts = args.strip().split(maxsplit=1)
name = parts[0] if parts else ""
description: str | None = None
if len(parts) > 1:
rest = parts[1]
lower = rest.lower()
idx = lower.find("description:")
if idx >= 0:
description = rest[idx + len("description:"):].strip() or None
return name, description
def setup_preset_commands(bot):
"""
Register preset commands with the bot.
Parameters
----------
bot : commands.Bot
The Discord bot instance.
"""
preset_manager = PresetManager()
@bot.command(name="preset", extras={"category": "Presets"})
@require_comfy_client
async def preset_command(ctx: commands.Context, *, args: str = "") -> None:
"""
Save, load, list, view, or delete named workflow presets.
A preset captures the current workflow template and all runtime
state changes (prompt, negative_prompt, input_image, seed) under a
short name. Load it later to restore everything in one step.
Usage:
ttr!preset save <name> [description:<text>] — save current workflow + state
ttr!preset load <name> — restore workflow + state
ttr!preset list — list all saved presets
ttr!preset view <name> — show preset details
ttr!preset delete <name> — permanently delete a preset
ttr!preset save-last <name> [description:<text>] — save last generation as preset
Names may only contain letters, digits, hyphens, and underscores.
Examples:
ttr!preset save portrait description:studio lighting style
ttr!preset load portrait
ttr!preset list
ttr!preset view portrait
ttr!preset delete portrait
ttr!preset save-last my-last
"""
parts = args.strip().split(maxsplit=1)
subcommand = parts[0].lower() if parts else ""
rest = parts[1].strip() if len(parts) > 1 else ""
if subcommand == "save":
name, description = _parse_name_and_description(rest)
await _preset_save(ctx, bot, preset_manager, name, description)
elif subcommand == "load":
await _preset_load(ctx, bot, preset_manager, rest.split()[0] if rest.split() else "")
elif subcommand == "list":
await _preset_list(ctx, preset_manager)
elif subcommand == "view":
await _preset_view(ctx, preset_manager, rest.split()[0] if rest.split() else "")
elif subcommand == "delete":
await _preset_delete(ctx, preset_manager, rest.split()[0] if rest.split() else "")
elif subcommand == "save-last":
name, description = _parse_name_and_description(rest)
await _preset_save_last(ctx, preset_manager, name, description)
else:
await ctx.reply(
"Usage: `ttr!preset <save|load|list|view|delete|save-last> [name]`\n"
"Run `ttr!help preset` for full details.",
mention_author=False,
)
async def _preset_save(
ctx: commands.Context, bot, preset_manager: PresetManager, name: str,
description: str | None = None,
) -> None:
"""Handle ttr!preset save <name> [description:<text>]."""
if not name:
await ctx.reply(
"Please provide a name. Example: `ttr!preset save portrait`",
mention_author=False,
)
return
if not PresetManager.is_valid_name(name):
await ctx.reply(
"Invalid name. Use only letters, digits, hyphens, and underscores (max 64 chars).",
mention_author=False,
)
return
try:
workflow_template = bot.comfy.get_workflow_template()
state = bot.comfy.get_workflow_current_changes()
preset_manager.save(name, workflow_template, state, description=description)
# Build a summary of what was saved
has_workflow = workflow_template is not None
state_parts = []
if state.get("prompt"):
state_parts.append("prompt")
if state.get("negative_prompt"):
state_parts.append("negative_prompt")
if state.get("input_image"):
state_parts.append("input_image")
if state.get("seed") is not None:
state_parts.append(f"seed={state['seed']}")
summary_parts = []
if has_workflow:
summary_parts.append("workflow template")
summary_parts.extend(state_parts)
summary = ", ".join(summary_parts) if summary_parts else "empty state"
desc_note = f"\n> {description}" if description else ""
await ctx.reply(
f"Preset **{name}** saved ({summary}).{desc_note}",
mention_author=False,
)
except Exception as exc:
logger.exception("Failed to save preset '%s'", name)
await ctx.reply(
f"Failed to save preset: {type(exc).__name__}: {exc}",
mention_author=False,
)
async def _preset_load(
ctx: commands.Context, bot, preset_manager: PresetManager, name: str
) -> None:
"""Handle ttr!preset load <name>."""
if not name:
await ctx.reply(
"Please provide a name. Example: `ttr!preset load portrait`",
mention_author=False,
)
return
data = preset_manager.load(name)
if data is None:
presets = preset_manager.list_presets()
hint = f" Available: {', '.join(presets)}" if presets else " No presets saved yet."
await ctx.reply(
f"Preset **{name}** not found.{hint}",
mention_author=False,
)
return
try:
restored: list[str] = []
# Restore workflow template if present
workflow = data.get("workflow")
if workflow is not None:
bot.comfy.set_workflow(workflow)
restored.append("workflow template")
# Restore state changes
state = data.get("state", {})
if state:
bot.comfy.set_workflow_current_changes(state)
if state.get("prompt"):
restored.append("prompt")
if state.get("negative_prompt"):
restored.append("negative_prompt")
if state.get("input_image"):
restored.append("input_image")
if state.get("seed") is not None:
restored.append(f"seed={state['seed']}")
summary = ", ".join(restored) if restored else "nothing (preset was empty)"
description = data.get("description")
desc_note = f"\n> {description}" if description else ""
await ctx.reply(
f"Preset **{name}** loaded ({summary}).{desc_note}",
mention_author=False,
)
except Exception as exc:
logger.exception("Failed to load preset '%s'", name)
await ctx.reply(
f"Failed to load preset: {type(exc).__name__}: {exc}",
mention_author=False,
)
async def _preset_view(
ctx: commands.Context, preset_manager: PresetManager, name: str
) -> None:
"""Handle ttr!preset view <name>."""
if not name:
await ctx.reply(
"Please provide a name. Example: `ttr!preset view portrait`",
mention_author=False,
)
return
data = preset_manager.load(name)
if data is None:
await ctx.reply(f"Preset **{name}** not found.", mention_author=False)
return
lines = [f"**Preset: {name}**"]
if data.get("description"):
lines.append(f"> {data['description']}")
if data.get("owner"):
lines.append(f"Owner: {data['owner']}")
state = data.get("state", {})
if state.get("prompt"):
# Truncate long prompts
p = str(state["prompt"])
if len(p) > 200:
p = p[:197] + ""
lines.append(f"**Prompt:** {p}")
if state.get("negative_prompt"):
np = str(state["negative_prompt"])
if len(np) > 100:
np = np[:97] + ""
lines.append(f"**Negative:** {np}")
if state.get("seed") is not None:
seed_note = " (random)" if state["seed"] == -1 else ""
lines.append(f"**Seed:** {state['seed']}{seed_note}")
other = {k: v for k, v in state.items() if k not in ("prompt", "negative_prompt", "seed", "input_image")}
if other:
other_str = ", ".join(f"{k}={v}" for k, v in other.items())
lines.append(f"**Other:** {other_str[:200]}")
if data.get("workflow") is not None:
lines.append("_(includes workflow template)_")
else:
lines.append("_(no workflow template — load separately)_")
await ctx.reply("\n".join(lines), mention_author=False)
async def _preset_list(ctx: commands.Context, preset_manager: PresetManager) -> None:
"""Handle ttr!preset list."""
presets = preset_manager.list_preset_details()
if not presets:
await ctx.reply(
"No presets saved yet. Use `ttr!preset save <name>` to create one.",
mention_author=False,
)
return
lines = [f"**Saved presets** ({len(presets)})"]
for p in presets:
entry = f"{p['name']}"
if p.get("description"):
entry += f"{p['description']}"
lines.append(entry)
lines.append("\nUse `ttr!preset load <name>` to restore one.")
await ctx.reply("\n".join(lines), mention_author=False)
async def _preset_delete(
ctx: commands.Context, preset_manager: PresetManager, name: str
) -> None:
"""Handle ttr!preset delete <name>."""
if not name:
await ctx.reply(
"Please provide a name. Example: `ttr!preset delete portrait`",
mention_author=False,
)
return
deleted = preset_manager.delete(name)
if deleted:
await ctx.reply(f"Preset **{name}** deleted.", mention_author=False)
else:
await ctx.reply(
f"Preset **{name}** not found.",
mention_author=False,
)
async def _preset_save_last(
ctx: commands.Context, preset_manager: PresetManager, name: str,
description: str | None = None,
) -> None:
"""Handle ttr!preset save-last <name> [description:<text>]."""
if not name:
await ctx.reply(
"Please provide a name. Example: `ttr!preset save-last my-last`",
mention_author=False,
)
return
if not PresetManager.is_valid_name(name):
await ctx.reply(
"Invalid name. Use only letters, digits, hyphens, and underscores (max 64 chars).",
mention_author=False,
)
return
from generation_db import get_history as db_get_history
history = db_get_history(limit=1)
if not history:
await ctx.reply(
"No generation history found. Generate something first!",
mention_author=False,
)
return
last = history[0]
overrides = last.get("overrides") or {}
try:
preset_manager.save(name, None, overrides, description=description)
desc_note = f"\n> {description}" if description else ""
await ctx.reply(
f"Preset **{name}** saved from last generation.{desc_note}\n"
"Note: workflow template not included — load it separately before generating.",
mention_author=False,
)
except ValueError as exc:
await ctx.reply(str(exc), mention_author=False)
except Exception as exc:
logger.exception("Failed to save preset '%s' from history", name)
await ctx.reply(
f"Failed to save preset: {type(exc).__name__}: {exc}",
mention_author=False,
)

484
commands/server.py Normal file
View File

@@ -0,0 +1,484 @@
"""
commands/server.py
==================
ComfyUI server lifecycle management via NSSM Windows service.
On bot startup, `autostart_comfy()` runs as a background task:
1. If the service does not exist, it is installed automatically.
2. If the service exists but ComfyUI is not responding, it is started.
NSSM handles:
- Background process management (no console window)
- Stdout / stderr capture to rotating log files
- Complete isolation from the bot's own NSSM service
Commands:
ttr!server start — start the service
ttr!server stop — stop the service
ttr!server restart — restart the service
ttr!server status — NSSM service state + HTTP reachability
ttr!server install — (re)install / reconfigure the NSSM service
ttr!server uninstall — remove the service from Windows
Requires:
- nssm.exe in PATH
- The bot service account must have permission to manage Windows services
(Local System or a user with SeServiceLogonRight works)
"""
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
import aiohttp
from discord.ext import commands
logger = logging.getLogger(__name__)
_POLL_INTERVAL = 5 # seconds between HTTP up-checks
_MAX_ATTEMPTS = 24 # 24 × 5s = 120s max wait
# Public — imported by status_monitor for emoji rendering
STATUS_EMOJI: dict[str, str] = {
"SERVICE_RUNNING": "🟢",
"SERVICE_STOPPED": "🔴",
"SERVICE_PAUSED": "🟡",
"SERVICE_START_PENDING": "",
"SERVICE_STOP_PENDING": "",
"SERVICE_PAUSE_PENDING": "",
"SERVICE_CONTINUE_PENDING": "",
}
# ---------------------------------------------------------------------------
# Low-level subprocess helpers
# ---------------------------------------------------------------------------
async def _nssm(*args: str) -> tuple[int, str]:
"""Run `nssm <args>` and return (returncode, stdout)."""
try:
proc = await asyncio.create_subprocess_exec(
"nssm", *args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=30)
return proc.returncode, stdout.decode(errors="replace").strip()
except FileNotFoundError:
return -1, "nssm not found — is it installed and in PATH?"
except asyncio.TimeoutError:
return -1, "nssm command timed out."
except Exception as exc:
return -1, str(exc)
async def _get_service_pid(service_name: str) -> int:
"""Return the PID of the process backing *service_name*, or 0 if unavailable."""
rc, out = await _nssm("getpid", service_name)
if rc != 0:
return 0
try:
return int(out.strip())
except ValueError:
return 0
async def _kill_service_process(service_name: str) -> None:
"""
Forcefully kill the process backing *service_name*.
NSSM does not have a `kill` subcommand. Instead we retrieve the PID
via `nssm getpid` and then use `taskkill /F /PID`. Safe to call when
the service is already stopped (no-op if PID is 0).
"""
pid = await _get_service_pid(service_name)
if not pid:
return
try:
proc = await asyncio.create_subprocess_exec(
"taskkill", "/F", "/PID", str(pid),
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.DEVNULL,
)
await asyncio.wait_for(proc.communicate(), timeout=10)
logger.debug("taskkill /F /PID %d sent for service '%s'", pid, service_name)
except Exception as exc:
logger.warning("taskkill failed for PID %d (%s): %s", pid, service_name, exc)
async def _is_comfy_up(server_address: str, timeout: float = 3.0) -> bool:
"""Return True if the ComfyUI HTTP endpoint is responding."""
url = f"http://{server_address}/system_stats"
try:
async with aiohttp.ClientSession() as session:
async with session.get(url, timeout=aiohttp.ClientTimeout(total=timeout)) as resp:
return resp.status == 200
except Exception:
return False
async def _service_exists(service_name: str) -> bool:
"""Return True if the Windows service is installed (running or stopped)."""
try:
proc = await asyncio.create_subprocess_exec(
"sc", "query", service_name,
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.DEVNULL,
)
await proc.communicate()
return proc.returncode == 0
except Exception:
return False
# ---------------------------------------------------------------------------
# Public API — used by status_monitor and other modules
# ---------------------------------------------------------------------------
async def get_service_state(service_name: str) -> str:
"""
Return the NSSM service state string for *service_name*.
Returns one of the SERVICE_* keys in STATUS_EMOJI on success, or
"error" / "timeout" / "unknown" on failure. Intended for use by
the status dashboard — callers should not raise on these sentinel values.
"""
try:
rc, out = await asyncio.wait_for(_nssm("status", service_name), timeout=5.0)
if rc == -1:
return "error"
return out.strip() or "unknown"
except asyncio.TimeoutError:
return "timeout"
except Exception:
return "error"
# ---------------------------------------------------------------------------
# Service installation
# ---------------------------------------------------------------------------
async def _install_service(config) -> tuple[bool, str]:
"""
Install the ComfyUI NSSM service with log capture and rotation.
We install directly via python.exe (not the .bat file) to avoid the
"Terminate batch job (Y/N)?" prompt that can cause NSSM to hang on STOP.
Safe to call even if the service already exists — it will be removed first.
Returns (success, message).
"""
name = config.comfy_service_name
start_bat = Path(config.comfy_start_bat)
log_dir = Path(config.comfy_log_dir)
log_file = str(log_dir / "comfyui.log")
max_bytes = str(config.comfy_log_max_mb * 1024 * 1024)
# Derive portable paths from the .bat location (ComfyUI_windows_portable root):
# <root>/run_nvidia_gpu.bat
# <root>/python_embeded/python.exe
# <root>/ComfyUI/main.py
portable_root = start_bat.parent
python_exe = portable_root / "python_embeded" / "python.exe"
main_py = portable_root / "ComfyUI" / "main.py"
if not start_bat.exists():
return False, f"Start bat not found (used to derive paths): `{start_bat}`"
if not python_exe.exists():
return False, f"Portable python not found: `{python_exe}`"
if not main_py.exists():
return False, f"ComfyUI main.py not found: `{main_py}`"
log_dir.mkdir(parents=True, exist_ok=True)
# Optional extra args from config (accepts string or list/tuple)
extra_args: list[str] = []
extra = getattr(config, "comfy_extra_args", None)
try:
if isinstance(extra, (list, tuple)):
extra_args = [str(x) for x in extra if str(x).strip()]
elif isinstance(extra, str) and extra.strip():
import shlex
extra_args = shlex.split(extra)
except Exception:
extra_args = [] # ignore parse errors rather than aborting install
# Remove any existing service cleanly before reinstalling
if await _service_exists(name):
await _nssm("stop", name)
await _kill_service_process(name) # force-kill if stuck in STOP_PENDING
rc, out = await _nssm("remove", name, "confirm")
if rc != 0:
return False, f"Could not remove existing service: {out}"
# nssm install <name> <python.exe> -s <main.py> --windows-standalone-build [extra]
steps: list[tuple[str, ...]] = [
("install", name, str(python_exe), "-s", str(main_py), "--windows-standalone-build", *extra_args),
("set", name, "AppDirectory", str(portable_root)),
("set", name, "DisplayName", "ComfyUI Server"),
("set", name, "AppStdout", log_file),
("set", name, "AppStderr", log_file),
("set", name, "AppRotateFiles", "1"),
("set", name, "AppRotateBytes", max_bytes),
("set", name, "AppRotateOnline", "1"),
("set", name, "Start", "SERVICE_DEMAND_START"),
# Stop behavior — prevent NSSM from hanging indefinitely
("set", name, "AppKillProcessTree", "1"),
("set", name, "AppStopMethodConsole", "1500"),
("set", name, "AppStopMethodWindow", "1500"),
("set", name, "AppStopMethodThreads", "1500"),
]
for step in steps:
rc, out = await _nssm(*step)
if rc != 0:
return False, f"`nssm {' '.join(step[:3])}` failed: {out}"
return True, f"Service `{name}` installed. Log: `{log_file}`"
# ---------------------------------------------------------------------------
# Autostart (called from bot.py on_ready)
# ---------------------------------------------------------------------------
async def autostart_comfy(config) -> None:
"""
Ensure ComfyUI is running when the bot starts.
1. Install the NSSM service if it is missing.
2. Start the service if ComfyUI is not already responding.
Does nothing if config.comfy_autostart is False.
"""
if not getattr(config, "comfy_autostart", True):
return
if not await _service_exists(config.comfy_service_name):
logger.info("NSSM service '%s' not found — installing", config.comfy_service_name)
ok, msg = await _install_service(config)
if not ok:
logger.error("Failed to install ComfyUI service: %s", msg)
return
logger.info("ComfyUI service installed: %s", msg)
if await _is_comfy_up(config.comfy_server):
logger.info("ComfyUI already running at %s", config.comfy_server)
return
logger.info("Starting NSSM service '%s'", config.comfy_service_name)
rc, out = await _nssm("start", config.comfy_service_name)
if rc != 0:
logger.warning("nssm start returned %d: %s", rc, out)
return
for attempt in range(_MAX_ATTEMPTS):
await asyncio.sleep(_POLL_INTERVAL)
if await _is_comfy_up(config.comfy_server):
logger.info("ComfyUI is up after ~%ds", (attempt + 1) * _POLL_INTERVAL)
return
logger.warning(
"ComfyUI did not respond within %ds after service start",
_MAX_ATTEMPTS * _POLL_INTERVAL,
)
# ---------------------------------------------------------------------------
# Discord commands
# ---------------------------------------------------------------------------
def setup_server_commands(bot, config=None):
"""Register ComfyUI server management commands."""
def _no_config(ctx):
"""Reply and return True when config is missing (guards every subcommand)."""
return config is None
@bot.group(name="server", invoke_without_command=True, extras={"category": "Server"})
async def server_group(ctx: commands.Context) -> None:
"""ComfyUI server management. Subcommands: start, stop, restart, status, install, uninstall."""
await ctx.send_help(ctx.command)
@server_group.command(name="start")
async def server_start(ctx: commands.Context) -> None:
"""Start the ComfyUI service."""
if config is None:
await ctx.reply("Bot config not available.", mention_author=False)
return
if await _is_comfy_up(config.comfy_server):
await ctx.reply("✅ ComfyUI is already running.", mention_author=False)
return
msg = await ctx.reply(
f"⏳ Starting service `{config.comfy_service_name}`…", mention_author=False
)
rc, out = await _nssm("start", config.comfy_service_name)
if rc != 0:
await msg.edit(content=f"❌ `{out}`")
return
await msg.edit(content="⏳ Waiting for ComfyUI to respond…")
for attempt in range(_MAX_ATTEMPTS):
await asyncio.sleep(_POLL_INTERVAL)
if await _is_comfy_up(config.comfy_server):
await msg.edit(
content=f"✅ ComfyUI is up! (took ~{(attempt + 1) * _POLL_INTERVAL}s)"
)
return
await msg.edit(content="⚠️ Service started but ComfyUI did not respond within 120 seconds.")
@server_group.command(name="stop")
async def server_stop(ctx: commands.Context) -> None:
"""Stop the ComfyUI service (force-kills if graceful stop fails)."""
if config is None:
await ctx.reply("Bot config not available.", mention_author=False)
return
msg = await ctx.reply(
f"⏳ Stopping service `{config.comfy_service_name}`…", mention_author=False
)
rc, out = await _nssm("stop", config.comfy_service_name)
if rc == 0:
await msg.edit(content="✅ ComfyUI service stopped.")
return
# Graceful stop failed (timed out or error) — force-kill the process.
await msg.edit(content="⏳ Graceful stop failed — force-killing process…")
await _kill_service_process(config.comfy_service_name)
await asyncio.sleep(2)
state = await get_service_state(config.comfy_service_name)
if state == "SERVICE_STOPPED":
await msg.edit(content="✅ ComfyUI service force-killed and stopped.")
else:
await msg.edit(
content=f"⚠️ Force-kill sent but service state is `{state}`. "
f"Use `ttr!server kill` to try again."
)
@server_group.command(name="kill")
async def server_kill(ctx: commands.Context) -> None:
"""Force-kill the ComfyUI process when it is stuck in STOPPING/STOP_PENDING."""
if config is None:
await ctx.reply("Bot config not available.", mention_author=False)
return
msg = await ctx.reply(
f"⏳ Force-killing `{config.comfy_service_name}` process…", mention_author=False
)
await _kill_service_process(config.comfy_service_name)
await asyncio.sleep(2)
state = await get_service_state(config.comfy_service_name)
emoji = STATUS_EMOJI.get(state, "")
await msg.edit(
content=f"💀 taskkill sent. Service state is now {emoji} `{state}`."
)
@server_group.command(name="restart")
async def server_restart(ctx: commands.Context) -> None:
"""Restart the ComfyUI service (force-kills if graceful stop fails)."""
if config is None:
await ctx.reply("Bot config not available.", mention_author=False)
return
msg = await ctx.reply(
f"⏳ Stopping `{config.comfy_service_name}` for restart…", mention_author=False
)
# Step 1: graceful stop.
rc, out = await _nssm("stop", config.comfy_service_name)
if rc != 0:
# Stop timed out or failed — force-kill so we can start fresh.
await msg.edit(content="⏳ Graceful stop failed — force-killing process…")
await _kill_service_process(config.comfy_service_name)
await asyncio.sleep(2)
# Step 2: verify stopped before starting.
state = await get_service_state(config.comfy_service_name)
if state not in ("SERVICE_STOPPED", "error", "unknown", "timeout"):
# Still not fully stopped — try one more force-kill.
await _kill_service_process(config.comfy_service_name)
await asyncio.sleep(2)
# Step 3: start.
await msg.edit(content=f"⏳ Starting `{config.comfy_service_name}`…")
rc, out = await _nssm("start", config.comfy_service_name)
if rc != 0:
await msg.edit(content=f"❌ Start failed: `{out}`")
return
# Step 4: wait for HTTP.
await msg.edit(content="⏳ Waiting for ComfyUI to come back up…")
for attempt in range(_MAX_ATTEMPTS):
await asyncio.sleep(_POLL_INTERVAL)
if await _is_comfy_up(config.comfy_server):
await msg.edit(
content=f"✅ ComfyUI is back up! (took ~{(attempt + 1) * _POLL_INTERVAL}s)"
)
return
await msg.edit(content="⚠️ Service started but ComfyUI did not respond within 120 seconds.")
@server_group.command(name="status")
async def server_status(ctx: commands.Context) -> None:
"""Show NSSM service state and HTTP reachability."""
if config is None:
await ctx.reply("Bot config not available.", mention_author=False)
return
state, http_up = await asyncio.gather(
get_service_state(config.comfy_service_name),
_is_comfy_up(config.comfy_server),
)
emoji = STATUS_EMOJI.get(state, "")
svc_line = f"{emoji} `{state}`"
http_line = (
f"🟢 Responding at `{config.comfy_server}`"
if http_up else
f"🔴 Not responding at `{config.comfy_server}`"
)
await ctx.reply(
f"**ComfyUI Server Status**\n"
f"Service `{config.comfy_service_name}`: {svc_line}\n"
f"HTTP: {http_line}",
mention_author=False,
)
@server_group.command(name="install")
async def server_install(ctx: commands.Context) -> None:
"""(Re)install the ComfyUI NSSM service with current config settings."""
if config is None:
await ctx.reply("Bot config not available.", mention_author=False)
return
msg = await ctx.reply(
f"⏳ Installing service `{config.comfy_service_name}`…", mention_author=False
)
ok, detail = await _install_service(config)
await msg.edit(content=f"{'' if ok else ''} {detail}")
@server_group.command(name="uninstall")
async def server_uninstall(ctx: commands.Context) -> None:
"""Stop and remove the ComfyUI NSSM service from Windows."""
if config is None:
await ctx.reply("Bot config not available.", mention_author=False)
return
msg = await ctx.reply(
f"⏳ Removing service `{config.comfy_service_name}`…", mention_author=False
)
await _nssm("stop", config.comfy_service_name)
await _kill_service_process(config.comfy_service_name)
rc, out = await _nssm("remove", config.comfy_service_name, "confirm")
if rc == 0:
await msg.edit(content=f"✅ Service `{config.comfy_service_name}` removed.")
else:
await msg.edit(content=f"❌ `{out}`")

268
commands/utility.py Normal file
View File

@@ -0,0 +1,268 @@
"""
commands/utility.py
===================
Quality-of-life utility commands for the Discord ComfyUI bot.
Commands provided:
- ping: Show bot latency (Discord WebSocket round-trip).
- status: Full overview of bot health, ComfyUI connectivity,
workflow state, and queue.
- queue-status: Quick view of pending job count and worker state.
- uptime: How long the bot has been running since it connected.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from discord.ext import commands
logger = logging.getLogger(__name__)
def _format_uptime(start_time: datetime) -> str:
"""Return a human-readable uptime string from a UTC start datetime."""
delta = datetime.now(timezone.utc) - start_time
total_seconds = int(delta.total_seconds())
days, remainder = divmod(total_seconds, 86400)
hours, remainder = divmod(remainder, 3600)
minutes, seconds = divmod(remainder, 60)
if days:
return f"{days}d {hours}h {minutes}m {seconds}s"
if hours:
return f"{hours}h {minutes}m {seconds}s"
return f"{minutes}m {seconds}s"
def setup_utility_commands(bot):
"""
Register quality-of-life utility commands with the bot.
Parameters
----------
bot : commands.Bot
The Discord bot instance.
"""
@bot.command(name="ping", extras={"category": "Utility"})
async def ping_command(ctx: commands.Context) -> None:
"""
Show the bot's current Discord WebSocket latency.
Usage:
ttr!ping
"""
latency_ms = round(bot.latency * 1000)
await ctx.reply(f"Pong! Latency: **{latency_ms} ms**", mention_author=False)
@bot.command(name="status", extras={"category": "Utility"})
async def status_command(ctx: commands.Context) -> None:
"""
Show a full health overview of the bot and ComfyUI.
Displays:
- Bot latency and uptime
- ComfyUI server address and reachability
- Whether a workflow template is loaded
- Current workflow changes (prompt / negative_prompt / input_image)
- Job queue size and worker state
Usage:
ttr!status
"""
latency_ms = round(bot.latency * 1000)
# Uptime
if hasattr(bot, "start_time") and bot.start_time:
uptime_str = _format_uptime(bot.start_time)
else:
uptime_str = "N/A"
# ComfyUI info
comfy_ok = hasattr(bot, "comfy") and bot.comfy is not None
comfy_server = bot.comfy.server_address if comfy_ok else "not configured"
comfy_reachable = await bot.comfy.check_connection() if comfy_ok else False
workflow_loaded = comfy_ok and bot.comfy.get_workflow_template() is not None
# ComfyUI queue
comfy_pending = 0
comfy_running = 0
if comfy_ok:
q = await bot.comfy.get_comfy_queue()
if q:
comfy_pending = len(q.get("queue_pending", []))
comfy_running = len(q.get("queue_running", []))
# Workflow state summary
changes_parts: list[str] = []
if comfy_ok:
overrides = bot.comfy.state_manager.get_overrides()
if overrides.get("prompt"):
changes_parts.append("prompt")
if overrides.get("negative_prompt"):
changes_parts.append("negative_prompt")
if overrides.get("input_image"):
changes_parts.append(f"input_image: {overrides['input_image']}")
if overrides.get("seed") is not None:
changes_parts.append(f"seed={overrides['seed']}")
changes_summary = ", ".join(changes_parts) if changes_parts else "none"
conn_status = (
"reachable" if comfy_reachable
else ("unreachable" if comfy_ok else "not configured")
)
lines = [
"**Bot**",
f" Latency : {latency_ms} ms",
f" Uptime : {uptime_str}",
"",
f"**ComfyUI** — `{comfy_server}`",
f" Connection : {conn_status}",
f" Queue : {comfy_running} running, {comfy_pending} pending",
f" Workflow : {'loaded' if workflow_loaded else 'not loaded'}",
f" Changes set : {changes_summary}",
]
await ctx.reply("\n".join(lines), mention_author=False)
@bot.command(name="queue-status", aliases=["qs", "qstatus"], extras={"category": "Utility"})
async def queue_status_command(ctx: commands.Context) -> None:
"""
Show the current ComfyUI queue depth.
Usage:
ttr!queue-status
ttr!qs
"""
if not hasattr(bot, "comfy") or not bot.comfy:
await ctx.reply("ComfyUI client is not configured.", mention_author=False)
return
q = await bot.comfy.get_comfy_queue()
if q is None:
await ctx.reply("Could not reach ComfyUI server.", mention_author=False)
return
pending = len(q.get("queue_pending", []))
running = len(q.get("queue_running", []))
await ctx.reply(
f"ComfyUI queue: **{running}** running, **{pending}** pending.",
mention_author=False,
)
@bot.command(name="uptime", extras={"category": "Utility"})
async def uptime_command(ctx: commands.Context) -> None:
"""
Show how long the bot has been running since it last connected.
Usage:
ttr!uptime
"""
if not hasattr(bot, "start_time") or not bot.start_time:
await ctx.reply("Uptime information is not available.", mention_author=False)
return
uptime_str = _format_uptime(bot.start_time)
await ctx.reply(f"Uptime: **{uptime_str}**", mention_author=False)
@bot.command(name="comfy-stats", aliases=["cstats"], extras={"category": "Utility"})
async def comfy_stats_command(ctx: commands.Context) -> None:
"""
Show GPU and system stats from the ComfyUI server.
Displays OS, Python version, and per-device VRAM usage reported
by the ComfyUI ``/system_stats`` endpoint.
Usage:
ttr!comfy-stats
ttr!cstats
"""
if not hasattr(bot, "comfy") or not bot.comfy:
await ctx.reply("ComfyUI client is not configured.", mention_author=False)
return
stats = await bot.comfy.get_system_stats()
if stats is None:
await ctx.reply(
"Could not reach the ComfyUI server to fetch stats.", mention_author=False
)
return
system = stats.get("system", {})
devices = stats.get("devices", [])
lines = [
f"**ComfyUI System Stats** — `{bot.comfy.server_address}`",
f" OS : {system.get('os', 'N/A')}",
f" Python : {system.get('python_version', 'N/A')}",
]
if devices:
lines.append("")
lines.append("**Devices**")
for dev in devices:
name = dev.get("name", "unknown")
vram_total = dev.get("vram_total", 0)
vram_free = dev.get("vram_free", 0)
vram_used = vram_total - vram_free
def _mb(b: int) -> str:
return f"{b / 1024 / 1024:.0f} MB"
lines.append(
f" {name}{_mb(vram_used)} / {_mb(vram_total)} VRAM used"
)
else:
lines.append(" No device info available.")
await ctx.reply("\n".join(lines), mention_author=False)
@bot.command(name="comfy-queue", aliases=["cqueue", "cq"], extras={"category": "Utility"})
async def comfy_queue_command(ctx: commands.Context) -> None:
"""
Show the ComfyUI server's internal queue state.
Displays jobs currently running and pending on the ComfyUI server
itself (separate from the Discord bot's own job queue).
Usage:
ttr!comfy-queue
ttr!cq
"""
if not hasattr(bot, "comfy") or not bot.comfy:
await ctx.reply("ComfyUI client is not configured.", mention_author=False)
return
queue_data = await bot.comfy.get_comfy_queue()
if queue_data is None:
await ctx.reply(
"Could not reach the ComfyUI server to fetch queue info.", mention_author=False
)
return
running = queue_data.get("queue_running", [])
pending = queue_data.get("queue_pending", [])
lines = [
f"**ComfyUI Server Queue** — `{bot.comfy.server_address}`",
f" Running : {len(running)} job(s)",
f" Pending : {len(pending)} job(s)",
]
if running:
lines.append("")
lines.append("**Currently running**")
for entry in running[:5]: # cap at 5 to avoid huge messages
prompt_id = entry[1] if len(entry) > 1 else "unknown"
lines.append(f" `{prompt_id}`")
if pending:
lines.append("")
lines.append(f"**Pending** (showing up to 5 of {len(pending)})")
for entry in pending[:5]:
prompt_id = entry[1] if len(entry) > 1 else "unknown"
lines.append(f" `{prompt_id}`")
await ctx.reply("\n".join(lines), mention_author=False)

100
commands/workflow.py Normal file
View File

@@ -0,0 +1,100 @@
"""
commands/workflow.py
====================
Workflow management commands for the Discord ComfyUI bot.
This module contains commands for loading and managing workflow templates.
"""
from __future__ import annotations
import json
import logging
from typing import Optional, Dict
from discord.ext import commands
from discord_utils import require_comfy_client
logger = logging.getLogger(__name__)
def setup_workflow_commands(bot):
"""
Register workflow management commands with the bot.
Parameters
----------
bot : commands.Bot
The Discord bot instance.
"""
@bot.command(name="workflow-load", aliases=["workflowload", "wfl"], extras={"category": "Workflow"})
@require_comfy_client
async def load_workflow_command(ctx: commands.Context, *, path: Optional[str] = None) -> None:
"""
Load a ComfyUI workflow from a JSON file.
Usage:
ttr!workflow-load path/to/workflow.json
You can also attach a JSON file to the command message instead of
providing a path. The loaded workflow will replace the current
workflow template used by the bot. After loading a workflow you
can generate images with your prompts while reusing the loaded
graph structure.
"""
workflow_data: Optional[Dict] = None
# Check for attached JSON file first
for attachment in ctx.message.attachments:
if attachment.filename.lower().endswith(".json"):
raw = await attachment.read()
try:
text = raw.decode("utf-8")
except UnicodeDecodeError as exc:
await ctx.reply(
f"`{attachment.filename}` is not valid UTF-8: {exc}",
mention_author=False,
)
return
try:
workflow_data = json.loads(text)
break
except json.JSONDecodeError as exc:
await ctx.reply(
f"Failed to parse `{attachment.filename}` as JSON: {exc}",
mention_author=False,
)
return
# Otherwise try to load from provided path
if workflow_data is None and path:
try:
with open(path, "r", encoding="utf-8") as f:
workflow_data = json.load(f)
except FileNotFoundError:
await ctx.reply(f"File not found: `{path}`", mention_author=False)
return
except json.JSONDecodeError as exc:
await ctx.reply(f"Invalid JSON in `{path}`: {exc}", mention_author=False)
return
if workflow_data is None:
await ctx.reply(
"Please provide a JSON workflow file either as an attachment or a path.",
mention_author=False,
)
return
# Set the workflow on the client
try:
bot.comfy.set_workflow(workflow_data)
await ctx.reply("Workflow loaded successfully.", mention_author=False)
except Exception as exc:
await ctx.reply(
f"Failed to set workflow: {type(exc).__name__}: {exc}",
mention_author=False,
)

View File

@@ -0,0 +1,252 @@
"""
commands/workflow_changes.py
============================
Workflow override management commands for the Discord ComfyUI bot.
Works with any NodeInput.key discovered by WorkflowInspector — not just
the four original hard-coded keys. Backward-compat aliases are preserved:
``type:prompt``, ``type:negative_prompt``, ``type:input_image``,
``type:seed``.
"""
from __future__ import annotations
import logging
from discord.ext import commands
from config import ARG_TYPE_KEY
from discord_utils import require_comfy_client
logger = logging.getLogger(__name__)
def setup_workflow_changes_commands(bot):
"""Register workflow changes commands with the bot."""
@bot.command(
name="get-current-workflow-changes",
aliases=["getworkflowchanges", "gcwc"],
extras={"category": "Workflow"},
)
@require_comfy_client
async def get_current_workflow_changes_command(
ctx: commands.Context, *, args: str = ""
) -> None:
"""
Show current workflow override values.
Usage::
ttr!get-current-workflow-changes type:all
ttr!get-current-workflow-changes type:prompt
ttr!get-current-workflow-changes type:<any_override_key>
"""
try:
overrides = bot.comfy.state_manager.get_overrides()
if ARG_TYPE_KEY not in args:
await ctx.reply(
f"Use `{ARG_TYPE_KEY}all` to see all overrides, or "
f"`{ARG_TYPE_KEY}<key>` for a specific key.",
mention_author=False,
)
return
param = args.split(ARG_TYPE_KEY, 1)[1].strip().lower()
if param == "all":
if not overrides:
await ctx.reply("No overrides set.", mention_author=False)
return
lines = [f"**{k}**: `{v}`" for k, v in sorted(overrides.items())]
await ctx.reply(
"Current overrides:\n" + "\n".join(lines),
mention_author=False,
)
else:
# Support multi-word value with the key as prefix
key = param.split()[0] if " " in param else param
val = overrides.get(key)
if val is None:
await ctx.reply(
f"Override `{key}` is not set.",
mention_author=False,
)
else:
await ctx.reply(
f"**{key}**: `{val}`",
mention_author=False,
)
except Exception as exc:
logger.exception("Failed to get workflow overrides")
await ctx.reply(f"An error occurred: {type(exc).__name__}: {exc}", mention_author=False)
@bot.command(
name="set-current-workflow-changes",
aliases=["setworkflowchanges", "scwc"],
extras={"category": "Workflow"},
)
@require_comfy_client
async def set_current_workflow_changes_command(
ctx: commands.Context, *, args: str = ""
) -> None:
"""
Set a workflow override value.
Supports any NodeInput.key discovered by WorkflowInspector as well
as the legacy fixed keys.
Usage::
ttr!set-current-workflow-changes type:<key> <value>
Examples::
ttr!scwc type:prompt A beautiful landscape
ttr!scwc type:negative_prompt blurry
ttr!scwc type:input_image my_image.png
ttr!scwc type:steps 30
ttr!scwc type:cfg 7.5
ttr!scwc type:seed 42
"""
try:
if not args or ARG_TYPE_KEY not in args:
await ctx.reply(
f"Usage: `ttr!set-current-workflow-changes {ARG_TYPE_KEY}<key> <value>`",
mention_author=False,
)
return
rest = args.split(ARG_TYPE_KEY, 1)[1]
# Key is the first word; value is everything after the first space
parts = rest.split(None, 1)
if len(parts) < 2:
await ctx.reply(
"Please provide both a key and a value. "
f"Example: `ttr!scwc {ARG_TYPE_KEY}prompt A cat`",
mention_author=False,
)
return
key = parts[0].strip().lower()
raw_value: str = parts[1].strip()
if not key:
await ctx.reply("Key cannot be empty.", mention_author=False)
return
# Type-coerce well-known numeric keys
_int_keys = {"steps", "width", "height"}
_float_keys = {"cfg", "denoise"}
_seed_keys = {"seed", "noise_seed"}
value: object = raw_value
try:
if key in _int_keys:
value = int(raw_value)
elif key in _float_keys:
value = float(raw_value)
elif key in _seed_keys:
value = int(raw_value)
except ValueError:
await ctx.reply(
f"Invalid value for `{key}`: expected a number, got `{raw_value}`.",
mention_author=False,
)
return
bot.comfy.state_manager.set_override(key, value)
await ctx.reply(
f"Override **{key}** set to `{value}`.",
mention_author=False,
)
except Exception as exc:
logger.exception("Failed to set workflow override")
await ctx.reply(f"An error occurred: {type(exc).__name__}: {exc}", mention_author=False)
@bot.command(
name="clear-workflow-change",
aliases=["clearworkflowchange", "cwc"],
extras={"category": "Workflow"},
)
@require_comfy_client
async def clear_workflow_change_command(
ctx: commands.Context, *, args: str = ""
) -> None:
"""
Remove a single override key.
Usage::
ttr!clear-workflow-change type:<key>
"""
try:
if ARG_TYPE_KEY not in args:
await ctx.reply(
f"Usage: `ttr!clear-workflow-change {ARG_TYPE_KEY}<key>`",
mention_author=False,
)
return
key = args.split(ARG_TYPE_KEY, 1)[1].strip().lower()
bot.comfy.state_manager.delete_override(key)
await ctx.reply(f"Override **{key}** cleared.", mention_author=False)
except Exception as exc:
logger.exception("Failed to clear override")
await ctx.reply(f"An error occurred: {type(exc).__name__}: {exc}", mention_author=False)
@bot.command(
name="set-seed",
aliases=["setseed"],
extras={"category": "Workflow"},
)
@require_comfy_client
async def set_seed_command(ctx: commands.Context, *, args: str = "") -> None:
"""
Pin a specific seed for deterministic generation.
Usage::
ttr!set-seed 42
"""
seed_str = args.strip()
if not seed_str:
await ctx.reply("Usage: `ttr!set-seed <number>`", mention_author=False)
return
if not seed_str.isdigit():
await ctx.reply("Seed must be a non-negative integer.", mention_author=False)
return
seed_val = int(seed_str)
max_seed = 2 ** 32 - 1
if seed_val > max_seed:
await ctx.reply(f"Seed must be between 0 and {max_seed}.", mention_author=False)
return
try:
bot.comfy.state_manager.set_seed(seed_val)
await ctx.reply(f"Seed pinned to `{seed_val}`.", mention_author=False)
except Exception as exc:
logger.exception("Failed to set seed")
await ctx.reply(f"An error occurred: {type(exc).__name__}: {exc}", mention_author=False)
@bot.command(
name="clear-seed",
aliases=["clearseed"],
extras={"category": "Workflow"},
)
@require_comfy_client
async def clear_seed_command(ctx: commands.Context) -> None:
"""
Clear the pinned seed and return to random generation.
Usage::
ttr!clear-seed
"""
try:
bot.comfy.state_manager.clear_seed()
await ctx.reply("Seed cleared; generation will now use random seeds.", mention_author=False)
except Exception as exc:
logger.exception("Failed to clear seed")
await ctx.reply(f"An error occurred: {type(exc).__name__}: {exc}", mention_author=False)

283
config.py Normal file
View File

@@ -0,0 +1,283 @@
"""
config.py
=========
Configuration module for the Discord ComfyUI bot.
This module centralizes all constants, magic strings, and environment
variable loading to make configuration management easier and more maintainable.
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
try:
from dotenv import load_dotenv
load_dotenv()
except Exception:
pass
# ========================================
# Command and Argument Constants
# ========================================
COMMAND_PREFIX = os.getenv("BOT_PREFIX", "ttr!")
"""The command prefix used for Discord bot commands."""
ARG_PROMPT_KEY = "prompt:"
"""The keyword marker for prompt arguments in commands."""
ARG_NEG_PROMPT_KEY = "negative_prompt:"
"""The keyword marker for negative prompt arguments in commands."""
ARG_TYPE_KEY = "type:"
"""The keyword marker for type arguments in commands."""
ARG_QUEUE_KEY = "queue:"
"""The keyword marker for queue count arguments in commands."""
# ========================================
# Discord and Message Constants
# ========================================
MAX_IMAGES_PER_RESPONSE = 4
"""Maximum number of images to include in a single Discord response."""
DEFAULT_UPLOAD_TYPE = "input"
"""Default folder type for ComfyUI image uploads."""
MESSAGE_AUTO_DELETE_TIMEOUT = 60.0
"""Default timeout in seconds for auto-deleting temporary messages."""
# ========================================
# Error Messages
# ========================================
COMFY_NOT_CONFIGURED_MSG = "ComfyUI client is not configured. Please set environment variables."
"""Error message displayed when ComfyUI client is not properly configured."""
# ========================================
# Default Configuration Values
# ========================================
DEFAULT_COMFY_HISTORY_LIMIT = 10
"""Default number of generation history entries to keep."""
# Resolve paths relative to this file's location so both the bot project and
# the portable ComfyUI folder only need to share the same parent directory.
# Layout assumed:
# <parent>/
# ComfyUI_windows_portable/ComfyUI/output ← default output
# ComfyUI_windows_portable/ComfyUI/input ← default input
# the-third-rev/ ← this project
_COMFY_PORTABLE_ROOT = Path(__file__).resolve().parent.parent / "ComfyUI_windows_portable" / "ComfyUI"
DEFAULT_COMFY_OUTPUT_PATH = str(_COMFY_PORTABLE_ROOT / "output")
DEFAULT_COMFY_INPUT_PATH = str(_COMFY_PORTABLE_ROOT / "input")
# ========================================
# Configuration Class
# ========================================
@dataclass
class BotConfig:
"""
Configuration container for the Discord ComfyUI bot.
This dataclass holds all configuration values loaded from environment
variables. Use the `from_env()` class method to create an instance
with values loaded from the environment.
Attributes
----------
discord_bot_token : str
Discord bot authentication token (required).
comfy_server : str
ComfyUI server address in format "hostname:port" (required).
comfy_output_path : str
Path to ComfyUI output directory for reading generated files.
comfy_history_limit : int
Number of generation history entries to keep in memory.
workflow_file : Optional[str]
Path to a workflow JSON file to load at startup (optional).
"""
discord_bot_token: str
comfy_server: str
comfy_output_path: str
comfy_input_path: str
comfy_history_limit: int
comfy_input_channel_id: int = 1475791295665405962
comfy_service_name: str = "ComfyUI"
comfy_start_bat: str = ""
comfy_log_dir: str = ""
comfy_log_max_mb: int = 10
comfy_autostart: bool = True
workflow_file: Optional[str] = None
log_channel_id: Optional[int] = None
zip_password: Optional[str] = None
media_upload_user: Optional[str] = None
media_upload_pass: Optional[str] = None
# Web UI fields
web_enabled: bool = True
web_host: str = "0.0.0.0"
web_port: int = 8080
web_secret_key: str = ""
web_token_file: str = "invite_tokens.json"
web_jwt_expire_hours: int = 8
web_secure_cookie: bool = True
admin_password: Optional[str] = None
@classmethod
def from_env(cls) -> BotConfig:
"""
Create a BotConfig instance by loading values from environment variables.
Environment Variables
---------------------
DISCORD_BOT_TOKEN : str (required)
Discord bot authentication token.
COMFY_SERVER : str (required)
ComfyUI server address (e.g., "localhost:8188" or "example.com:8188").
COMFY_OUTPUT_PATH : str (optional)
Path to ComfyUI output directory. Defaults to DEFAULT_COMFY_OUTPUT_PATH
if not specified.
COMFY_HISTORY_LIMIT : int (optional)
Number of generation history entries to keep. Defaults to
DEFAULT_COMFY_HISTORY_LIMIT if not specified or invalid.
WORKFLOW_FILE : str (optional)
Path to a workflow JSON file to load at startup.
Returns
-------
BotConfig
A configured BotConfig instance.
Raises
------
RuntimeError
If required environment variables (DISCORD_BOT_TOKEN or COMFY_SERVER)
are not set.
"""
# Load required variables
discord_token = os.getenv("DISCORD_BOT_TOKEN")
if not discord_token:
raise RuntimeError(
"DISCORD_BOT_TOKEN environment variable is required. "
"Please set it in your .env file or environment."
)
comfy_server = os.getenv("COMFY_SERVER")
if not comfy_server:
raise RuntimeError(
"COMFY_SERVER environment variable is required. "
"Please set it in your .env file or environment."
)
# Load optional variables with defaults
comfy_output_path = os.getenv("COMFY_OUTPUT_PATH", DEFAULT_COMFY_OUTPUT_PATH)
comfy_input_path = os.getenv("COMFY_INPUT_PATH", DEFAULT_COMFY_INPUT_PATH)
# Parse history limit with fallback to default
try:
comfy_history_limit = int(os.getenv("COMFY_HISTORY_LIMIT", str(DEFAULT_COMFY_HISTORY_LIMIT)))
except ValueError:
comfy_history_limit = DEFAULT_COMFY_HISTORY_LIMIT
workflow_file = os.getenv("WORKFLOW_FILE")
log_channel_id_str = os.getenv("LOG_CHANNEL_ID", "1475408462740721809")
try:
log_channel_id = int(log_channel_id_str) if log_channel_id_str else None
except ValueError:
log_channel_id = None
zip_password = os.getenv("ZIP_PASSWORD", "0Revel512796@")
media_upload_user = os.getenv("MEDIA_UPLOAD_USER") or None
media_upload_pass = os.getenv("MEDIA_UPLOAD_PASS") or None
try:
comfy_input_channel_id = int(os.getenv("COMFY_INPUT_CHANNEL_ID", "1475791295665405962"))
except ValueError:
comfy_input_channel_id = 1475791295665405962
comfy_service_name = os.getenv("COMFY_SERVICE_NAME", "ComfyUI")
default_bat = str(_COMFY_PORTABLE_ROOT.parent / "run_nvidia_gpu.bat")
comfy_start_bat = os.getenv("COMFY_START_BAT", default_bat)
default_log_dir = str(_COMFY_PORTABLE_ROOT.parent / "logs")
comfy_log_dir = os.getenv("COMFY_LOG_DIR", default_log_dir)
try:
comfy_log_max_mb = int(os.getenv("COMFY_LOG_MAX_MB", "10"))
except ValueError:
comfy_log_max_mb = 10
comfy_autostart = os.getenv("COMFY_AUTOSTART", "true").lower() not in ("false", "0", "no")
# Web UI config
web_enabled = os.getenv("WEB_ENABLED", "true").lower() not in ("false", "0", "no")
web_host = os.getenv("WEB_HOST", "0.0.0.0")
try:
web_port = int(os.getenv("WEB_PORT", "8080"))
except ValueError:
web_port = 8080
web_secret_key = os.getenv("WEB_SECRET_KEY", "")
web_token_file = os.getenv("WEB_TOKEN_FILE", "invite_tokens.json")
try:
web_jwt_expire_hours = int(os.getenv("WEB_JWT_EXPIRE_HOURS", "8"))
except ValueError:
web_jwt_expire_hours = 8
web_secure_cookie = os.getenv("WEB_SECURE_COOKIE", "true").lower() not in ("false", "0", "no")
admin_password = os.getenv("ADMIN_PASSWORD") or None
return cls(
discord_bot_token=discord_token,
comfy_server=comfy_server,
comfy_output_path=comfy_output_path,
comfy_input_path=comfy_input_path,
comfy_history_limit=comfy_history_limit,
comfy_input_channel_id=comfy_input_channel_id,
comfy_service_name=comfy_service_name,
comfy_start_bat=comfy_start_bat,
comfy_log_dir=comfy_log_dir,
comfy_log_max_mb=comfy_log_max_mb,
comfy_autostart=comfy_autostart,
workflow_file=workflow_file,
log_channel_id=log_channel_id,
zip_password=zip_password,
media_upload_user=media_upload_user,
media_upload_pass=media_upload_pass,
web_enabled=web_enabled,
web_host=web_host,
web_port=web_port,
web_secret_key=web_secret_key,
web_token_file=web_token_file,
web_jwt_expire_hours=web_jwt_expire_hours,
web_secure_cookie=web_secure_cookie,
admin_password=admin_password,
)
def __repr__(self) -> str:
"""Return a string representation with sensitive data masked."""
return (
f"BotConfig("
f"discord_bot_token='***masked***', "
f"comfy_server='{self.comfy_server}', "
f"comfy_output_path='{self.comfy_output_path}', "
f"comfy_input_path='{self.comfy_input_path}', "
f"comfy_history_limit={self.comfy_history_limit}, "
f"comfy_input_channel_id={self.comfy_input_channel_id}, "
f"workflow_file={self.workflow_file!r}, "
f"log_channel_id={self.log_channel_id!r}, "
f"zip_password={'***masked***' if self.zip_password else None})"
)

251
discord_utils.py Normal file
View File

@@ -0,0 +1,251 @@
"""
discord_utils.py
================
Discord utility functions and helpers for the Discord ComfyUI bot.
This module provides reusable Discord-specific utilities including:
- Command decorators for validation
- Argument parsing helpers
- Message formatting utilities
- Discord UI components
"""
from __future__ import annotations
import functools
from io import BytesIO
from typing import Dict, Optional, List, Tuple
import discord
from discord.ext import commands
from discord.ui import View
from config import COMFY_NOT_CONFIGURED_MSG
def require_comfy_client(func):
"""
Decorator that validates bot.comfy exists before executing a command.
This decorator checks if the bot has a configured ComfyClient instance
(bot.comfy) and sends an error message if not. This eliminates the need
for repeated validation code in every command.
Usage
-----
@bot.command(name="generate")
@require_comfy_client
async def generate_command(ctx: commands.Context, *, args: str = ""):
# bot.comfy is guaranteed to exist here
await bot.comfy.generate_image(...)
Parameters
----------
func : callable
The command function to wrap.
Returns
-------
callable
The wrapped command function with ComfyClient validation.
"""
@functools.wraps(func)
async def wrapper(ctx: commands.Context, *args, **kwargs):
bot = ctx.bot
if not hasattr(bot, "comfy") or bot.comfy is None:
await ctx.reply(COMFY_NOT_CONFIGURED_MSG, mention_author=False)
return
return await func(ctx, *args, **kwargs)
return wrapper
def parse_labeled_args(args: str, keys: List[str]) -> Dict[str, Optional[str]]:
"""
Parse labeled arguments from a command string.
This parser handles Discord command arguments in the format:
"key1:value1 key2:value2 ..."
The parser splits on keyword markers and preserves case in values.
If a key is not found in the args string, its value will be None.
Parameters
----------
args : str
The argument string to parse.
keys : List[str]
List of keys to extract (e.g., ["prompt:", "negative_prompt:"]).
Returns
-------
Dict[str, Optional[str]]
Dictionary mapping keys (without colons) to their values.
Keys not found in args will have None values.
Examples
--------
>>> parse_labeled_args("prompt:a cat negative_prompt:blurry", ["prompt:", "negative_prompt:"])
{"prompt": "a cat", "negative_prompt": "blurry"}
>>> parse_labeled_args("prompt:hello world", ["prompt:", "type:"])
{"prompt": "hello world", "type": None}
"""
result = {key.rstrip(":"): None for key in keys}
remaining = args
# Sort keys by position in string to parse left-to-right
found_keys = []
for key in keys:
if key in remaining:
idx = remaining.find(key)
found_keys.append((idx, key))
found_keys.sort()
for i, (_, key) in enumerate(found_keys):
# Split on this key
parts = remaining.split(key, 1)
if len(parts) < 2:
continue
value_part = parts[1]
# Find the next key, if any
next_key_idx = len(value_part)
if i + 1 < len(found_keys):
next_key = found_keys[i + 1][1]
if next_key in value_part:
next_key_idx = value_part.find(next_key)
# Extract value up to next key
value = value_part[:next_key_idx].strip()
result[key.rstrip(":")] = value if value else None
return result
def convert_image_bytes_to_discord_files(
images: List[bytes], max_files: int = 4, prefix: str = "generated"
) -> List[discord.File]:
"""
Convert a list of image bytes to Discord File objects.
Parameters
----------
images : List[bytes]
List of raw image data as bytes.
max_files : int
Maximum number of files to convert (default: 4, Discord's limit).
prefix : str
Filename prefix for generated files (default: "generated").
Returns
-------
List[discord.File]
List of Discord.File objects ready to send.
"""
files = []
for idx, img_bytes in enumerate(images):
if idx >= max_files:
break
file_obj = BytesIO(img_bytes)
file_obj.seek(0)
files.append(discord.File(file_obj, filename=f"{prefix}_{idx + 1}.png"))
return files
async def send_queue_status(ctx: commands.Context, queue_size: int) -> None:
"""
Send a queue status message to the channel.
Parameters
----------
ctx : commands.Context
The command context.
queue_size : int
Current number of jobs in the queue.
"""
await ctx.send(f"Queue size: {queue_size}", mention_author=False)
async def send_typing_with_callback(ctx: commands.Context, callback):
"""
Execute a callback while showing typing indicator.
Parameters
----------
ctx : commands.Context
The command context.
callback : callable
Async function to execute while typing.
Returns
-------
Any
The return value of the callback.
"""
async with ctx.typing():
return await callback()
def truncate_text(text: str, length: int = 50) -> str:
"""
Truncate text to a maximum length with ellipsis.
Parameters
----------
text : str
The text to truncate.
length : int
Maximum length (default: 50).
Returns
-------
str
Truncated text with "..." suffix if longer than length.
"""
return text if len(text) <= length else text[: length - 3] + "..."
def extract_arg_value(args: str, key: str) -> Tuple[Optional[str], str]:
"""
Extract a single argument value from a labeled args string.
This is a simpler alternative to parse_labeled_args for extracting just
one value.
Parameters
----------
args : str
The full argument string.
key : str
The key to extract (e.g., "type:").
Returns
-------
Tuple[Optional[str], str]
A tuple of (extracted_value, remaining_args). If key not found,
returns (None, original_args).
Examples
--------
>>> extract_arg_value("type:input some other text", "type:")
("input", "some other text")
"""
if key not in args:
return None, args
parts = args.split(key, 1)
if len(parts) < 2:
return None, args
value_and_rest = parts[1].strip()
# Take first word as value
words = value_and_rest.split(None, 1)
value = words[0] if words else None
remaining = words[1] if len(words) > 1 else ""
return value, remaining

12
frontend/index.html Normal file
View File

@@ -0,0 +1,12 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>ComfyUI Bot</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
</body>
</html>

2774
frontend/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

27
frontend/package.json Normal file
View File

@@ -0,0 +1,27 @@
{
"name": "comfyui-bot-ui",
"version": "0.1.0",
"private": true,
"type": "module",
"scripts": {
"dev": "vite",
"build": "tsc && vite build",
"preview": "vite preview"
},
"dependencies": {
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-router-dom": "^6.27.0",
"@tanstack/react-query": "^5.62.0"
},
"devDependencies": {
"@types/react": "^18.3.12",
"@types/react-dom": "^18.3.1",
"@vitejs/plugin-react": "^4.3.3",
"autoprefixer": "^10.4.20",
"postcss": "^8.4.49",
"tailwindcss": "^3.4.15",
"typescript": "^5.6.3",
"vite": "^5.4.11"
}
}

View File

@@ -0,0 +1,6 @@
export default {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
}

68
frontend/src/App.tsx Normal file
View File

@@ -0,0 +1,68 @@
import React from 'react'
import { BrowserRouter, Routes, Route, Navigate } from 'react-router-dom'
import { useAuth } from './hooks/useAuth'
import Layout from './components/Layout'
import { GenerationProvider } from './context/GenerationContext'
import LoginPage from './pages/LoginPage'
import GeneratePage from './pages/GeneratePage'
import InputImagesPage from './pages/InputImagesPage'
import WorkflowPage from './pages/WorkflowPage'
import PresetsPage from './pages/PresetsPage'
import StatusPage from './pages/StatusPage'
import ServerPage from './pages/ServerPage'
import HistoryPage from './pages/HistoryPage'
import AdminPage from './pages/AdminPage'
import SharePage from './pages/SharePage'
function RequireAuth({ children }: { children: React.ReactNode }) {
const { isAuthenticated, isLoading } = useAuth()
if (isLoading) return <div className="flex items-center justify-center h-screen text-gray-500">Loading...</div>
if (!isAuthenticated) return <Navigate to="/login" replace />
return <>{children}</>
}
function RequireAdmin({ children }: { children: React.ReactNode }) {
const { isAdmin, isLoading } = useAuth()
if (isLoading) return null
if (!isAdmin) return <Navigate to="/generate" replace />
return <>{children}</>
}
export default function App() {
return (
<BrowserRouter>
<Routes>
<Route path="/login" element={<LoginPage />} />
<Route
path="/"
element={
<RequireAuth>
<GenerationProvider>
<Layout />
</GenerationProvider>
</RequireAuth>
}
>
<Route index element={<Navigate to="/generate" replace />} />
<Route path="generate" element={<GeneratePage />} />
<Route path="inputs" element={<InputImagesPage />} />
<Route path="workflow" element={<WorkflowPage />} />
<Route path="presets" element={<PresetsPage />} />
<Route path="status" element={<StatusPage />} />
<Route path="server" element={<ServerPage />} />
<Route path="history" element={<HistoryPage />} />
<Route
path="admin"
element={
<RequireAdmin>
<AdminPage />
</RequireAdmin>
}
/>
</Route>
<Route path="/share/:token" element={<SharePage />} />
<Route path="*" element={<Navigate to="/generate" replace />} />
</Routes>
</BrowserRouter>
)
}

198
frontend/src/api/client.ts Normal file
View File

@@ -0,0 +1,198 @@
/** Typed API client for the ComfyUI Bot web API. */
const BASE = '' // same-origin in prod; Vite proxy in dev
async function _fetch<T>(path: string, init?: RequestInit): Promise<T> {
const res = await fetch(BASE + path, {
credentials: 'include',
headers: { 'Content-Type': 'application/json', ...init?.headers },
...init,
})
if (!res.ok) {
const msg = await res.text().catch(() => res.statusText)
throw new Error(`${res.status}: ${msg}`)
}
return res.json() as Promise<T>
}
// Auth
export const authLogin = (token: string) =>
_fetch<{ label: string; admin: boolean }>('/api/auth/login', {
method: 'POST',
body: JSON.stringify({ token }),
})
export const authLogout = () =>
_fetch<{ ok: boolean }>('/api/auth/logout', { method: 'POST' })
export const authMe = () =>
_fetch<{ label: string; admin: boolean }>('/api/auth/me')
// Admin
export const adminLogin = (password: string) =>
_fetch<{ label: string; admin: boolean }>('/api/admin/login', {
method: 'POST',
body: JSON.stringify({ password }),
})
export const adminListTokens = () =>
_fetch<Array<{ id: string; label: string; admin: boolean; created_at: string }>>('/api/admin/tokens')
export const adminCreateToken = (label: string, admin = false) =>
_fetch<{ token: string; label: string; admin: boolean }>('/api/admin/tokens', {
method: 'POST',
body: JSON.stringify({ label, admin }),
})
export const adminRevokeToken = (id: string) =>
_fetch<{ ok: boolean }>(`/api/admin/tokens/${id}`, { method: 'DELETE' })
// Status
export const getStatus = () => _fetch<Record<string, unknown>>('/api/status')
// State / overrides
export const getState = () => _fetch<Record<string, unknown>>('/api/state')
export const putState = (overrides: Record<string, unknown>) =>
_fetch<Record<string, unknown>>('/api/state', {
method: 'PUT',
body: JSON.stringify(overrides),
})
export const deleteStateKey = (key: string) =>
_fetch<{ ok: boolean }>(`/api/state/${key}`, { method: 'DELETE' })
// Generation
export interface GenerateRequest {
prompt: string
negative_prompt?: string
overrides?: Record<string, unknown>
}
export const generate = (body: GenerateRequest) =>
_fetch<{ queued: boolean; queue_position: number }>('/api/generate', {
method: 'POST',
body: JSON.stringify(body),
})
export interface WorkflowGenRequest {
count?: number
overrides?: Record<string, unknown>
}
export const workflowGen = (body: WorkflowGenRequest) =>
_fetch<{ queued: boolean; count: number; queue_position: number }>('/api/workflow-gen', {
method: 'POST',
body: JSON.stringify(body),
})
// Inputs
export interface InputImage {
id: number
original_message_id: number
bot_reply_id: number | null
channel_id: number
filename: string
is_active: number
active_slot_key: string | null
}
export const listInputs = () => _fetch<InputImage[]>('/api/inputs')
export const uploadInput = (file: File, slotKey = 'input_image') => {
const form = new FormData()
form.append('file', file)
form.append('slot_key', slotKey)
return fetch('/api/inputs', {
method: 'POST',
credentials: 'include',
body: form,
}).then(r => r.json())
}
export const activateInput = (id: number, slotKey = 'input_image') =>
_fetch<{ ok: boolean }>(`/api/inputs/${id}/activate?slot_key=${slotKey}`, { method: 'POST' })
export const deleteInput = (id: number) =>
_fetch<{ ok: boolean }>(`/api/inputs/${id}`, { method: 'DELETE' })
export const getInputImage = (id: number) => `/api/inputs/${id}/image`
export const getInputThumb = (id: number) => `/api/inputs/${id}/thumb`
export const getInputMid = (id: number) => `/api/inputs/${id}/mid`
// Presets
export interface PresetMeta { name: string; owner: string | null; description: string | null }
export const listPresets = () => _fetch<{ presets: PresetMeta[] }>('/api/presets')
export const savePreset = (name: string, description?: string) =>
_fetch<{ ok: boolean }>('/api/presets', { method: 'POST', body: JSON.stringify({ name, description: description ?? null }) })
export const getPreset = (name: string) => _fetch<Record<string, unknown>>(`/api/presets/${name}`)
export const loadPreset = (name: string) =>
_fetch<{ ok: boolean }>(`/api/presets/${name}/load`, { method: 'POST' })
export const deletePreset = (name: string) =>
_fetch<{ ok: boolean }>(`/api/presets/${name}`, { method: 'DELETE' })
export const savePresetFromHistory = (promptId: string, name: string, description?: string) =>
_fetch<{ ok: boolean; name: string }>(`/api/presets/from-history/${promptId}`, {
method: 'POST',
body: JSON.stringify({ name, description: description ?? null }),
})
// Server
export const getServerStatus = () =>
_fetch<{ service_state: string; http_reachable: boolean }>('/api/server/status')
export const serverAction = (action: string) =>
_fetch<{ ok: boolean }>(`/api/server/${action}`, { method: 'POST' })
export const tailLogs = (lines = 100) =>
_fetch<{ lines: string[] }>(`/api/logs/tail?lines=${lines}`)
// History
export const getHistory = (q?: string) =>
_fetch<{ history: Array<Record<string, unknown>> }>(q ? `/api/history?q=${encodeURIComponent(q)}` : '/api/history')
export const createHistoryShare = (promptId: string) =>
_fetch<{ share_token: string }>(`/api/history/${promptId}/share`, { method: 'POST' })
export const revokeHistoryShare = (promptId: string) =>
_fetch<{ ok: boolean }>(`/api/history/${promptId}/share`, { method: 'DELETE' })
export const getShareFileUrl = (token: string, filename: string) =>
`/api/share/${token}/file/${encodeURIComponent(filename)}`
// Workflow
export const getWorkflow = () =>
_fetch<{ loaded: boolean; node_count: number; last_workflow_file: string | null }>('/api/workflow')
export interface NodeInput {
key: string
label: string
input_type: string
current_value: unknown
node_class: string
node_title: string
is_common: boolean
}
export const getWorkflowInputs = () =>
_fetch<{ common: NodeInput[]; advanced: NodeInput[] }>('/api/workflow/inputs')
export const listWorkflowFiles = () =>
_fetch<{ files: string[] }>('/api/workflow/files')
export const uploadWorkflow = (file: File) => {
const form = new FormData()
form.append('file', file)
return fetch('/api/workflow/upload', {
method: 'POST',
credentials: 'include',
body: form,
}).then(r => r.json())
}
export const loadWorkflow = (filename: string) => {
const form = new FormData()
form.append('filename', filename)
return fetch('/api/workflow/load', {
method: 'POST',
credentials: 'include',
body: form,
}).then(r => r.json())
}
export const getModels = (type: 'checkpoints' | 'loras') =>
_fetch<{ type: string; models: string[] }>(`/api/workflow/models?type=${type}`)

View File

@@ -0,0 +1,329 @@
import React, { useEffect, useState } from 'react'
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import {
getWorkflowInputs,
getModels,
putState,
deleteStateKey,
getState,
NodeInput,
activateInput,
listInputs,
getInputImage,
getInputThumb,
getInputMid,
} from '../api/client'
import LazyImage from './LazyImage'
interface Props {
/** Called when the Generate button is clicked with the current overrides */
onGenerate: (overrides: Record<string, unknown>, count: number) => void
/** Live seed from WS generation_complete event */
lastSeed?: number | null
generating?: boolean
/** The authenticated user's label (used to find their active input image) */
userLabel?: string
}
export default function DynamicWorkflowForm({ onGenerate, lastSeed, generating, userLabel }: Props) {
const qc = useQueryClient()
const { data: inputsData, isLoading: inputsLoading } = useQuery({
queryKey: ['workflow', 'inputs'],
queryFn: getWorkflowInputs,
})
const { data: stateData } = useQuery({
queryKey: ['state'],
queryFn: getState,
})
const { data: checkpoints } = useQuery({
queryKey: ['models', 'checkpoints'],
queryFn: () => getModels('checkpoints'),
staleTime: 60_000,
})
const { data: loras } = useQuery({
queryKey: ['models', 'loras'],
queryFn: () => getModels('loras'),
staleTime: 60_000,
})
const { data: inputImages } = useQuery({
queryKey: ['inputs'],
queryFn: listInputs,
})
const [localValues, setLocalValues] = useState<Record<string, unknown>>({})
const [randomSeeds, setRandomSeeds] = useState<Record<string, boolean>>({})
const [imagePicker, setImagePicker] = useState<string | null>(null) // key of slot being picked
const [count, setCount] = useState(1)
// Sync local values from state when stateData arrives
useEffect(() => {
if (stateData) setLocalValues(stateData as Record<string, unknown>)
}, [stateData])
// Update seed field when WS reports completed seed
useEffect(() => {
if (lastSeed != null) {
setLocalValues(v => ({ ...v, seed: lastSeed }))
}
}, [lastSeed])
const putStateMut = useMutation({
mutationFn: (overrides: Record<string, unknown>) => putState(overrides),
onSuccess: () => qc.invalidateQueries({ queryKey: ['state'] }),
})
const deleteKeyMut = useMutation({
mutationFn: (key: string) => deleteStateKey(key),
onSuccess: () => qc.invalidateQueries({ queryKey: ['state'] }),
})
const setValue = (key: string, value: unknown) => {
setLocalValues(v => ({ ...v, [key]: value }))
putStateMut.mutate({ [key]: value })
}
const handleActivateImage = async (imageId: number, slotKey: string) => {
await activateInput(imageId, slotKey)
qc.invalidateQueries({ queryKey: ['inputs'] })
qc.invalidateQueries({ queryKey: ['state'] })
setImagePicker(null)
}
const handleGenerate = () => {
const overrides: Record<string, unknown> = {}
const allInputs = [...(inputsData?.common ?? []), ...(inputsData?.advanced ?? [])]
for (const inp of allInputs) {
if (inp.input_type === 'seed') {
overrides[inp.key] = randomSeeds[inp.key] !== false ? -1 : (localValues[inp.key] ?? -1)
} else if (inp.input_type === 'image') {
// image slot — server reads from state_manager
} else {
const v = localValues[inp.key]
if (v !== undefined && v !== '') overrides[inp.key] = v
}
}
onGenerate(overrides, count)
}
if (inputsLoading) return <div className="text-sm text-gray-400">Loading workflow inputs</div>
if (!inputsData) return <div className="text-sm text-gray-400">No workflow loaded.</div>
const renderField = (inp: NodeInput) => {
const val = localValues[inp.key] ?? inp.current_value
if (inp.input_type === 'text') {
return (
<textarea
rows={3}
className="w-full border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 resize-y focus:outline-none focus:ring-1 focus:ring-blue-500"
value={String(val ?? '')}
onChange={e => setValue(inp.key, e.target.value)}
/>
)
}
if (inp.input_type === 'seed') {
const isRandom = randomSeeds[inp.key] !== false
return (
<div className="flex gap-2 items-center">
<input
type="number"
className="flex-1 border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500 disabled:opacity-40"
value={isRandom ? '' : String(val ?? '')}
placeholder={isRandom ? 'Random' : undefined}
disabled={isRandom}
onChange={e => setValue(inp.key, Number(e.target.value))}
/>
<button
type="button"
onClick={() => setRandomSeeds(r => ({ ...r, [inp.key]: !isRandom }))}
className={`text-xs px-2 py-1 rounded border ${isRandom ? 'bg-blue-600 text-white border-blue-600' : 'border-gray-400 text-gray-600 dark:text-gray-300'}`}
>
🎲 Random
</button>
</div>
)
}
if (inp.input_type === 'image') {
const activeFilename = String(val ?? '')
const namespacedKey = userLabel ? `${userLabel}_${inp.key}` : inp.key
const activeImg = inputImages?.find(i => i.active_slot_key === namespacedKey)
return (
<div className="flex items-center gap-2">
{activeImg ? (
<img
src={getInputThumb(activeImg.id)}
alt={activeImg.filename}
className="w-16 h-16 object-cover rounded border border-gray-300 dark:border-gray-600"
/>
) : (
<div className="w-16 h-16 rounded border border-dashed border-gray-400 flex items-center justify-center text-xs text-gray-400">none</div>
)}
<div className="flex flex-col gap-1">
<span className="text-xs text-gray-500 dark:text-gray-400 truncate max-w-[12rem]">{activeFilename || 'No image active'}</span>
<button
type="button"
onClick={() => setImagePicker(imagePicker === inp.key ? null : inp.key)}
className="text-xs bg-gray-200 dark:bg-gray-600 hover:bg-gray-300 dark:hover:bg-gray-500 rounded px-2 py-0.5"
>
Browse
</button>
</div>
</div>
)
}
if (inp.input_type === 'checkpoint') {
return (
<select
className="w-full border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
value={String(val ?? '')}
onChange={e => setValue(inp.key, e.target.value)}
>
<option value=""> select checkpoint </option>
{(checkpoints?.models ?? []).map(m => <option key={m} value={m}>{m}</option>)}
</select>
)
}
if (inp.input_type === 'lora') {
return (
<select
className="w-full border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
value={String(val ?? '')}
onChange={e => setValue(inp.key, e.target.value)}
>
<option value=""> select lora </option>
{(loras?.models ?? []).map(m => <option key={m} value={m}>{m}</option>)}
</select>
)
}
// integer, float, string
return (
<input
type={inp.input_type === 'integer' || inp.input_type === 'float' ? 'number' : 'text'}
step={inp.input_type === 'float' ? 'any' : undefined}
className="w-full border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
value={String(val ?? '')}
placeholder={String(inp.current_value ?? '')}
onChange={e => {
const raw = e.target.value
const coerced = inp.input_type === 'integer' ? parseInt(raw) || raw
: inp.input_type === 'float' ? parseFloat(raw) || raw
: raw
setValue(inp.key, coerced)
}}
/>
)
}
return (
<div className="space-y-4">
{/* Common inputs */}
{inputsData.common.map(inp => (
<div key={inp.key}>
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300 mb-1">{inp.label}</label>
{renderField(inp)}
{imagePicker === inp.key && (
<ImagePickerGrid
images={inputImages ?? []}
slotKey={inp.key}
onPick={handleActivateImage}
onClose={() => setImagePicker(null)}
/>
)}
</div>
))}
{/* Advanced inputs */}
{inputsData.advanced.length > 0 && (
<details className="border border-gray-200 dark:border-gray-700 rounded">
<summary className="px-3 py-2 text-sm font-medium cursor-pointer select-none text-gray-700 dark:text-gray-300">
Advanced ({inputsData.advanced.length} inputs)
</summary>
<div className="px-3 pb-3 space-y-3 mt-2">
{inputsData.advanced.map(inp => (
<div key={inp.key}>
<label className="block text-xs font-medium text-gray-600 dark:text-gray-400 mb-1">{inp.label}</label>
{renderField(inp)}
{imagePicker === inp.key && (
<ImagePickerGrid
images={inputImages ?? []}
slotKey={inp.key}
onPick={handleActivateImage}
onClose={() => setImagePicker(null)}
/>
)}
</div>
))}
</div>
</details>
)}
<div className="flex gap-2 items-center">
<input
type="number"
min={1}
max={20}
value={count}
onChange={e => setCount(Math.max(1, Math.min(20, Number(e.target.value))))}
className="w-16 border border-gray-300 dark:border-gray-600 rounded px-2 py-2 text-sm text-center bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
title="Number of generations to queue"
/>
<button
type="button"
onClick={handleGenerate}
disabled={generating}
className="flex-1 bg-blue-600 hover:bg-blue-700 disabled:opacity-50 text-white rounded px-4 py-2 text-sm font-semibold transition-colors"
>
{generating ? '⏳ Generating…' : count > 1 ? `Generate ×${count}` : 'Generate'}
</button>
</div>
</div>
)
}
function ImagePickerGrid({
images,
slotKey,
onPick,
onClose,
}: {
images: { id: number; filename: string }[]
slotKey: string
onPick: (id: number, key: string) => void
onClose: () => void
}) {
return (
<div className="mt-2 border border-gray-300 dark:border-gray-600 rounded p-2 bg-gray-50 dark:bg-gray-800">
<div className="flex justify-between items-center mb-2">
<span className="text-xs text-gray-500 dark:text-gray-400">Select image for slot: {slotKey}</span>
<button onClick={onClose} className="text-xs text-gray-400 hover:text-gray-600 dark:hover:text-gray-200"></button>
</div>
{images.length === 0 ? (
<p className="text-xs text-gray-400">No images available. Upload some first.</p>
) : (
<div className="grid grid-cols-3 sm:grid-cols-4 gap-1 max-h-48 overflow-y-auto">
{images.map(img => (
<button
key={img.id}
type="button"
onClick={() => onPick(img.id, slotKey)}
className="relative aspect-square overflow-hidden rounded border border-gray-200 dark:border-gray-600 hover:border-blue-500"
title={img.filename}
>
<LazyImage
thumbSrc={getInputThumb(img.id)}
midSrc={getInputMid(img.id)}
fullSrc={getInputImage(img.id)}
alt={img.filename}
className="w-full h-full"
/>
</button>
))}
</div>
)}
</div>
)
}

View File

@@ -0,0 +1,146 @@
import React, { useEffect, useRef, useState, useCallback } from 'react'
import { NavLink, Outlet, useLocation } from 'react-router-dom'
import { useQueryClient } from '@tanstack/react-query'
import { useAuth } from '../hooks/useAuth'
import { useStatus } from '../hooks/useStatus'
import { useGeneration } from '../context/GenerationContext'
const navItems = [
{ to: '/generate', label: 'Generate' },
{ to: '/inputs', label: 'Input Images' },
{ to: '/workflow', label: 'Workflow' },
{ to: '/presets', label: 'Presets' },
{ to: '/status', label: 'Status' },
{ to: '/server', label: 'Server' },
{ to: '/history', label: 'History' },
]
export default function Layout() {
const { user, logout, isAdmin } = useAuth()
const location = useLocation()
const [sidebarOpen, setSidebarOpen] = useState(false)
const [dark, setDark] = useState(() => {
if (typeof window === 'undefined') return false
const stored = localStorage.getItem('dark-mode')
if (stored !== null) return stored === 'true'
return window.matchMedia('(prefers-color-scheme: dark)').matches
})
// Apply dark class on mount and changes
useEffect(() => {
document.documentElement.classList.toggle('dark', dark)
localStorage.setItem('dark-mode', String(dark))
}, [dark])
// Auto-close sidebar on navigation
useEffect(() => {
setSidebarOpen(false)
}, [location.pathname])
const { pendingCount, decrementPending } = useGeneration()
const queryClient = useQueryClient()
const titleResetTimer = useRef<ReturnType<typeof setTimeout> | null>(null)
const onNodeExecuting = useCallback(() => {
document.title = '⏳ Generating… | ComfyUI Bot'
}, [])
const onGenerationComplete = useCallback(() => {
decrementPending()
queryClient.invalidateQueries({ queryKey: ['history'] })
if (titleResetTimer.current) clearTimeout(titleResetTimer.current)
document.title = 'Done | ComfyUI Bot'
titleResetTimer.current = setTimeout(() => { document.title = 'ComfyUI Bot' }, 5000)
}, [decrementPending, queryClient])
const onGenerationError = useCallback(() => {
decrementPending()
if (titleResetTimer.current) clearTimeout(titleResetTimer.current)
document.title = 'Error | ComfyUI Bot'
titleResetTimer.current = setTimeout(() => { document.title = 'ComfyUI Bot' }, 5000)
}, [decrementPending])
useEffect(() => () => { document.title = 'ComfyUI Bot' }, [])
const { status } = useStatus({ enabled: !!user, onGenerationComplete, onGenerationError, onNodeExecuting })
const comfyReachable = status.comfy?.reachable ?? null
const toggleDark = () => setDark(d => !d)
return (
<div className="flex h-screen overflow-hidden">
{/* Mobile backdrop */}
{sidebarOpen && (
<div
className="fixed inset-0 bg-black/40 z-30 md:hidden"
onClick={() => setSidebarOpen(false)}
/>
)}
{/* Sidebar */}
<aside
className={`fixed md:static inset-y-0 left-0 z-40 w-48 flex-none bg-gray-800 text-gray-100 flex flex-col transition-transform duration-200 ${
sidebarOpen ? 'translate-x-0' : '-translate-x-full md:translate-x-0'
}`}
>
<div className="p-4 font-bold text-lg border-b border-gray-700 flex items-center gap-2">
<span>ComfyUI Bot</span>
<span
title={comfyReachable == null ? 'Connecting…' : comfyReachable ? 'ComfyUI reachable' : 'ComfyUI unreachable'}
className={`ml-auto w-2 h-2 rounded-full flex-none ${comfyReachable == null ? 'bg-gray-500' : comfyReachable ? 'bg-green-400' : 'bg-red-400'}`}
/>
</div>
<nav className="flex-1 overflow-y-auto py-2">
{navItems.map(({ to, label }) => (
<NavLink
key={to}
to={to}
className={({ isActive }) =>
`flex items-center px-4 py-2 text-sm hover:bg-gray-700 transition-colors ${isActive ? 'bg-gray-700 font-medium' : ''}`
}
>
<span className="flex-1">{label}</span>
{to === '/generate' && pendingCount > 0 && (
<span className="ml-2 text-xs bg-yellow-400 text-gray-900 rounded-full px-1.5 py-0.5 leading-none">
{pendingCount}
</span>
)}
</NavLink>
))}
{isAdmin && (
<NavLink
to="/admin"
className={({ isActive }) =>
`block px-4 py-2 text-sm hover:bg-gray-700 transition-colors ${isActive ? 'bg-gray-700 font-medium' : ''}`
}
>
Admin
</NavLink>
)}
</nav>
<div className="p-3 border-t border-gray-700 text-xs flex items-center gap-2">
<span className="flex-1 truncate">{user?.label ?? '...'}</span>
<button onClick={toggleDark} className="hover:text-yellow-300" title="Toggle dark mode">
{dark ? '☀' : '🌙'}
</button>
<button onClick={logout} className="hover:text-red-400" title="Logout">
</button>
</div>
</aside>
{/* Main */}
<main className="flex-1 overflow-y-auto p-3 sm:p-6 bg-gray-50 dark:bg-gray-900">
{/* Hamburger button — mobile only */}
<button
className="md:hidden mb-3 p-1.5 rounded bg-gray-200 dark:bg-gray-700 text-gray-700 dark:text-gray-300 hover:bg-gray-300 dark:hover:bg-gray-600"
onClick={() => setSidebarOpen(true)}
aria-label="Open menu"
>
</button>
<Outlet />
</main>
</div>
)
}

View File

@@ -0,0 +1,59 @@
import React, { useState } from 'react'
interface Props {
thumbSrc: string
midSrc: string
fullSrc: string
alt: string
className?: string
onClick?: () => void
}
/**
* 3-stage progressive image loader:
* stage 0 — blurred tiny thumb (loads instantly)
* stage 1 — clear medium-compressed image (fades in when ready)
* stage 2 — full original image (fades in when ready)
*
* All three requests fire in parallel; stage only advances forward so if full
* arrives before mid we jump straight to stage 2.
*/
export default function LazyImage({ thumbSrc, midSrc, fullSrc, alt, className = '', onClick }: Props) {
const [stage, setStage] = useState(0)
const advance = (to: number) => setStage(s => Math.max(s, to))
return (
<div className={`relative overflow-hidden ${className}`}>
{/* Stage 0: blurred thumbnail */}
<img
src={thumbSrc}
alt=""
aria-hidden="true"
className={`absolute inset-0 w-full h-full object-cover transition-opacity duration-300 ${
stage === 0 ? 'opacity-100 blur-sm scale-105' : 'opacity-0'
}`}
/>
{/* Stage 1: clear mid-resolution */}
<img
src={midSrc}
alt=""
aria-hidden="true"
className={`absolute inset-0 w-full h-full object-cover transition-opacity duration-300 ${
stage === 1 ? 'opacity-100' : 'opacity-0'
}`}
onLoad={() => advance(1)}
/>
{/* Stage 2: full resolution */}
<img
src={fullSrc}
alt={alt}
loading="lazy"
className={`relative w-full h-full object-cover transition-opacity duration-300 ${
stage === 2 ? 'opacity-100' : 'opacity-0'
}`}
onLoad={() => advance(2)}
onClick={onClick}
/>
</div>
)
}

View File

@@ -0,0 +1,22 @@
import React, { createContext, useContext, useState, useCallback } from 'react'
interface Value {
pendingCount: number
addPending: (n?: number) => void
decrementPending: () => void
}
const Ctx = createContext<Value>({
pendingCount: 0,
addPending: () => {},
decrementPending: () => {},
})
export function GenerationProvider({ children }: { children: React.ReactNode }) {
const [pendingCount, setPendingCount] = useState(0)
const addPending = useCallback((n = 1) => setPendingCount(c => c + n), [])
const decrementPending = useCallback(() => setPendingCount(c => Math.max(0, c - 1)), [])
return <Ctx.Provider value={{ pendingCount, addPending, decrementPending }}>{children}</Ctx.Provider>
}
export const useGeneration = () => useContext(Ctx)

View File

@@ -0,0 +1,26 @@
import { useQuery, useQueryClient } from '@tanstack/react-query'
import { authMe, authLogout } from '../api/client'
export function useAuth() {
const qc = useQueryClient()
const { data, isLoading, error } = useQuery({
queryKey: ['auth', 'me'],
queryFn: authMe,
retry: false,
staleTime: 60_000,
})
const logout = async () => {
await authLogout()
qc.clear()
window.location.href = '/login'
}
return {
user: data ?? null,
isLoading,
isAuthenticated: !!data,
isAdmin: data?.admin ?? false,
logout,
}
}

View File

@@ -0,0 +1,71 @@
import { useState, useCallback } from 'react'
import { useWebSocket } from './useWebSocket'
export interface StatusSnapshot {
bot?: { latency_ms: number; uptime: string }
comfy?: {
server: string
reachable?: boolean
queue_pending: number
queue_running: number
workflow_loaded: boolean
last_seed: number | null
total_generated: number
}
overrides?: Record<string, unknown>
service?: { state: string }
upload?: { configured: boolean; running: boolean; total_ok: number; total_fail: number }
}
export interface GenerationResult {
prompt_id: string
seed: number | null
image_count: number
video_count: number
}
interface UseStatusOptions {
enabled?: boolean
onGenerationComplete?: (result: GenerationResult) => void
onGenerationError?: (error: { prompt_id: string | null; error: string }) => void
onNodeExecuting?: (node: string, promptId: string) => void
}
export function useStatus({
enabled = true,
onGenerationComplete,
onGenerationError,
onNodeExecuting,
}: UseStatusOptions) {
const [status, setStatus] = useState<StatusSnapshot>({})
const [executingNode, setExecutingNode] = useState<string | null>(null)
const handleMessage = useCallback(
(msg: { type: string; data: unknown; ts: number }) => {
if (msg.type === 'status_snapshot') {
setStatus(msg.data as StatusSnapshot)
} else if (msg.type === 'node_executing') {
const d = msg.data as { node: string; prompt_id: string }
setExecutingNode(d.node)
onNodeExecuting?.(d.node, d.prompt_id)
} else if (msg.type === 'generation_complete') {
setExecutingNode(null)
onGenerationComplete?.(msg.data as GenerationResult)
} else if (msg.type === 'generation_error') {
setExecutingNode(null)
onGenerationError?.(msg.data as { prompt_id: string | null; error: string })
} else if (msg.type === 'server_state') {
const d = msg.data as { state: string; http_reachable: boolean }
setStatus(prev => ({
...prev,
service: { state: d.state },
}))
}
},
[onGenerationComplete, onGenerationError, onNodeExecuting],
)
useWebSocket({ onMessage: handleMessage, enabled })
return { status, executingNode }
}

View File

@@ -0,0 +1,55 @@
import { useEffect, useRef, useCallback } from 'react'
interface WSOptions {
onMessage: (data: { type: string; data: unknown; ts: number }) => void
enabled?: boolean
}
/** Reconnecting WebSocket hook with exponential backoff. Auth via ttb_session cookie. */
export function useWebSocket({ onMessage, enabled = true }: WSOptions) {
const wsRef = useRef<WebSocket | null>(null)
const backoffRef = useRef(1000)
const timerRef = useRef<ReturnType<typeof setTimeout> | null>(null)
const onMessageRef = useRef(onMessage)
onMessageRef.current = onMessage
const connect = useCallback(() => {
if (!enabled) return
const proto = window.location.protocol === 'https:' ? 'wss' : 'ws'
const url = `${proto}://${window.location.host}/ws`
const ws = new WebSocket(url)
wsRef.current = ws
ws.onmessage = (ev) => {
try {
const parsed = JSON.parse(ev.data)
if (parsed.type !== 'ping') {
onMessageRef.current(parsed)
}
} catch {}
}
ws.onopen = () => {
backoffRef.current = 1000
}
ws.onclose = () => {
if (enabled) {
timerRef.current = setTimeout(() => {
backoffRef.current = Math.min(backoffRef.current * 2, 30_000)
connect()
}, backoffRef.current)
}
}
ws.onerror = () => ws.close()
}, [enabled])
useEffect(() => {
connect()
return () => {
enabled && (wsRef.current?.close())
if (timerRef.current) clearTimeout(timerRef.current)
}
}, [connect, enabled])
}

12
frontend/src/index.css Normal file
View File

@@ -0,0 +1,12 @@
@tailwind base;
@tailwind components;
@tailwind utilities;
:root {
color-scheme: light dark;
}
body {
@apply bg-gray-50 text-gray-900 dark:bg-gray-900 dark:text-gray-100;
font-family: system-ui, -apple-system, sans-serif;
}

19
frontend/src/main.tsx Normal file
View File

@@ -0,0 +1,19 @@
import React from 'react'
import ReactDOM from 'react-dom/client'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import App from './App'
import './index.css'
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: 1, staleTime: 5000 },
},
})
ReactDOM.createRoot(document.getElementById('root')!).render(
<React.StrictMode>
<QueryClientProvider client={queryClient}>
<App />
</QueryClientProvider>
</React.StrictMode>,
)

View File

@@ -0,0 +1,113 @@
import React, { useState } from 'react'
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { adminListTokens, adminCreateToken, adminRevokeToken } from '../api/client'
export default function AdminPage() {
const qc = useQueryClient()
const [label, setLabel] = useState('')
const [isAdmin, setIsAdmin] = useState(false)
const [newToken, setNewToken] = useState<string | null>(null)
const [createError, setCreateError] = useState('')
const { data: tokens = [], isLoading } = useQuery({
queryKey: ['admin', 'tokens'],
queryFn: adminListTokens,
})
const createMut = useMutation({
mutationFn: ({ label, admin }: { label: string; admin: boolean }) => adminCreateToken(label, admin),
onSuccess: (res) => {
qc.invalidateQueries({ queryKey: ['admin', 'tokens'] })
setNewToken(res.token)
setLabel('')
setIsAdmin(false)
setCreateError('')
},
onError: (err) => setCreateError(err instanceof Error ? err.message : String(err)),
})
const revokeMut = useMutation({
mutationFn: (id: string) => adminRevokeToken(id),
onSuccess: () => qc.invalidateQueries({ queryKey: ['admin', 'tokens'] }),
})
return (
<div className="max-w-xl mx-auto space-y-6">
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Admin Token Management</h1>
{/* Create token */}
<div className="bg-white dark:bg-gray-800 rounded border border-gray-200 dark:border-gray-700 p-4 space-y-3">
<p className="text-sm font-medium text-gray-700 dark:text-gray-300">Create invite token</p>
<div className="flex flex-col sm:flex-row gap-2">
<input
type="text"
value={label}
onChange={e => setLabel(e.target.value)}
placeholder="Label (e.g. alice)"
className="flex-1 border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-2 focus:ring-blue-500"
/>
<label className="flex items-center gap-1 text-sm text-gray-600 dark:text-gray-400 cursor-pointer">
<input
type="checkbox"
checked={isAdmin}
onChange={e => setIsAdmin(e.target.checked)}
className="rounded"
/>
Admin
</label>
<button
onClick={() => createMut.mutate({ label, admin: isAdmin })}
disabled={!label.trim() || createMut.isPending}
className="bg-blue-600 hover:bg-blue-700 disabled:opacity-50 text-white rounded px-3 py-2 text-sm font-medium"
>
Create
</button>
</div>
{createError && <p className="text-red-500 text-sm">{createError}</p>}
</div>
{/* New token display — one-time */}
{newToken && (
<div className="bg-yellow-50 dark:bg-yellow-900/30 border border-yellow-300 dark:border-yellow-700 rounded p-4 space-y-2">
<p className="text-sm font-medium text-yellow-800 dark:text-yellow-200">
New token (copy now shown only once):
</p>
<code className="block text-xs break-all bg-yellow-100 dark:bg-yellow-900/50 rounded p-2 text-yellow-900 dark:text-yellow-100 select-all">
{newToken}
</code>
<button
onClick={() => setNewToken(null)}
className="text-xs text-yellow-600 dark:text-yellow-400 hover:underline"
>
Dismiss
</button>
</div>
)}
{/* Token list */}
{isLoading ? (
<div className="text-sm text-gray-400">Loading tokens</div>
) : tokens.length === 0 ? (
<p className="text-sm text-gray-400">No tokens yet.</p>
) : (
<div className="divide-y divide-gray-200 dark:divide-gray-700 border border-gray-200 dark:border-gray-700 rounded">
{tokens.map(t => (
<div key={t.id} className="flex items-center justify-between px-3 py-2 text-sm">
<div>
<span className="font-medium text-gray-700 dark:text-gray-300">{t.label}</span>
{t.admin && <span className="ml-1 text-xs text-purple-600 dark:text-purple-400">(admin)</span>}
<span className="ml-2 text-xs text-gray-400">{new Date(t.created_at).toLocaleDateString()}</span>
</div>
<button
onClick={() => revokeMut.mutate(t.id)}
className="text-xs text-red-500 hover:text-red-700 dark:hover:text-red-400"
>
Revoke
</button>
</div>
))}
</div>
)}
</div>
)
}

View File

@@ -0,0 +1,250 @@
import React, { useState, useCallback } from 'react'
import { useQuery, useQueryClient } from '@tanstack/react-query'
import { generate, workflowGen, listPresets, loadPreset } from '../api/client'
import { useStatus, GenerationResult } from '../hooks/useStatus'
import { useAuth } from '../hooks/useAuth'
import { useGeneration } from '../context/GenerationContext'
import DynamicWorkflowForm from '../components/DynamicWorkflowForm'
interface Notification {
id: number
type: 'success' | 'error' | 'info'
msg: string
}
export default function GeneratePage() {
const { user } = useAuth()
const { pendingCount, addPending } = useGeneration()
const qc = useQueryClient()
const [generating, setGenerating] = useState(false)
const [notifications, setNotifications] = useState<Notification[]>([])
const [executingNodeDisplay, setExecutingNodeDisplay] = useState<string | null>(null)
const [lastSeed, setLastSeed] = useState<number | null>(null)
const [mode, setMode] = useState<'workflow' | 'prompt'>('workflow')
const [prompt, setPrompt] = useState('')
const [negPrompt, setNegPrompt] = useState('')
const [promptCount, setPromptCount] = useState(1)
const [loadingPreset, setLoadingPreset] = useState(false)
const { data: presetsData } = useQuery({ queryKey: ['presets'], queryFn: listPresets })
const addNotif = useCallback((type: Notification['type'], msg: string, ttl = 8000) => {
const id = Date.now()
setNotifications(n => [...n, { id, type, msg }])
setTimeout(() => setNotifications(n => n.filter(x => x.id !== id)), ttl)
}, [])
const handlePresetLoad = useCallback(async (name: string) => {
if (!name) return
setLoadingPreset(true)
try {
await loadPreset(name)
qc.invalidateQueries({ queryKey: ['state'] })
qc.invalidateQueries({ queryKey: ['workflowInputs'] })
addNotif('info', `Loaded preset: ${name}`, 3000)
} catch (err: unknown) {
addNotif('error', `Failed to load preset: ${err instanceof Error ? err.message : String(err)}`)
} finally {
setLoadingPreset(false)
}
}, [qc, addNotif])
const dismissNotif = useCallback((id: number) => {
setNotifications(n => n.filter(x => x.id !== id))
}, [])
const onGenerationComplete = useCallback((r: GenerationResult) => {
setExecutingNodeDisplay(null)
setLastSeed(r.seed)
addNotif('success', `Done — seed: ${r.seed ?? 'unknown'} · ${r.image_count} image(s) · ${r.video_count} video(s)`)
}, [addNotif])
const onGenerationError = useCallback((e: { prompt_id: string | null; error: string }) => {
setExecutingNodeDisplay(null)
addNotif('error', e.error)
}, [addNotif])
const onNodeExecuting = useCallback((node: string) => {
setExecutingNodeDisplay(node)
}, [])
const { status } = useStatus({
enabled: !!user,
onGenerationComplete,
onGenerationError,
onNodeExecuting,
})
const handleWorkflowGenerate = async (overrides: Record<string, unknown>, count: number = 1) => {
setGenerating(true)
document.title = '⏳ Queued… | ComfyUI Bot'
try {
const res = await workflowGen({ count, overrides })
setGenerating(false)
addPending(res.count ?? count)
addNotif('info', `Queued ${res.count ?? count} generation(s) at position ${res.queue_position}`, 3000)
} catch (err: unknown) {
setGenerating(false)
addNotif('error', err instanceof Error ? err.message : String(err))
document.title = 'ComfyUI Bot'
}
}
const handlePromptGenerate = async () => {
const n = Math.max(1, Math.min(20, promptCount))
setGenerating(true)
document.title = '⏳ Queued… | ComfyUI Bot'
try {
let lastPos = 0
for (let i = 0; i < n; i++) {
const res = await generate({ prompt, negative_prompt: negPrompt })
lastPos = res.queue_position
}
setGenerating(false)
addPending(n)
addNotif('info', `Queued ${n} generation(s) at position ${lastPos}`, 3000)
} catch (err: unknown) {
setGenerating(false)
addNotif('error', err instanceof Error ? err.message : String(err))
document.title = 'ComfyUI Bot'
}
}
const queuePending = (status.comfy?.queue_pending ?? 0)
const queueRunning = (status.comfy?.queue_running ?? 0)
return (
<div className="max-w-2xl mx-auto space-y-6">
<div className="flex items-center justify-between">
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Generate</h1>
<div className="text-xs text-gray-400">
ComfyUI: {queueRunning} running, {queuePending} pending
</div>
</div>
{/* Mode toggle */}
<div className="flex flex-wrap gap-2 text-sm">
<button
onClick={() => setMode('workflow')}
className={`px-3 py-1.5 rounded ${mode === 'workflow' ? 'bg-blue-600 text-white' : 'bg-gray-200 dark:bg-gray-700 text-gray-700 dark:text-gray-300'}`}
>
Workflow mode
</button>
<button
onClick={() => setMode('prompt')}
className={`px-3 py-1.5 rounded ${mode === 'prompt' ? 'bg-blue-600 text-white' : 'bg-gray-200 dark:bg-gray-700 text-gray-700 dark:text-gray-300'}`}
>
Prompt mode
</button>
</div>
{/* Quick-load preset */}
{(presetsData?.presets ?? []).length > 0 && (
<div className="flex items-center gap-2">
<select
defaultValue=""
disabled={loadingPreset}
onChange={e => {
const v = e.target.value
e.target.value = ''
if (v) handlePresetLoad(v)
}}
className="flex-1 border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-700 dark:text-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-500 disabled:opacity-50"
>
<option value="" disabled>Load a preset</option>
{(presetsData?.presets ?? []).map(p => (
<option key={p.name} value={p.name}>
{p.name}{p.description ? `${p.description}` : ''}
</option>
))}
</select>
{loadingPreset && <span className="text-xs text-gray-400">Loading</span>}
</div>
)}
{/* Progress banner */}
{(pendingCount > 0 || executingNodeDisplay) && (
<div className="bg-blue-50 dark:bg-blue-900/30 border border-blue-200 dark:border-blue-700 rounded p-3 text-sm text-blue-700 dark:text-blue-300">
{pendingCount > 0 && `${pendingCount} generation(s) in progress`}
{executingNodeDisplay && ` · running: ${executingNodeDisplay}`}
</div>
)}
{/* Form */}
{mode === 'workflow' ? (
<DynamicWorkflowForm
onGenerate={handleWorkflowGenerate}
lastSeed={lastSeed}
generating={generating}
userLabel={user?.label}
/>
) : (
<div className="space-y-4">
<div>
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300 mb-1">Prompt</label>
<textarea
rows={3}
className="w-full border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 resize-y focus:outline-none focus:ring-2 focus:ring-blue-500"
value={prompt}
onChange={e => setPrompt(e.target.value)}
placeholder="Describe what you want to generate"
/>
</div>
<div>
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300 mb-1">Negative prompt</label>
<textarea
rows={2}
className="w-full border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 resize-y focus:outline-none focus:ring-2 focus:ring-blue-500"
value={negPrompt}
onChange={e => setNegPrompt(e.target.value)}
placeholder="What to avoid"
/>
</div>
<div className="flex gap-2 items-center">
<input
type="number"
min={1}
max={20}
value={promptCount}
onChange={e => setPromptCount(Math.max(1, Math.min(20, Number(e.target.value))))}
className="w-16 border border-gray-300 dark:border-gray-600 rounded px-2 py-2 text-sm text-center bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
title="Number of generations to queue"
/>
<button
onClick={handlePromptGenerate}
disabled={generating || !prompt.trim()}
className="flex-1 bg-blue-600 hover:bg-blue-700 disabled:opacity-50 text-white rounded px-4 py-2 text-sm font-semibold transition-colors"
>
{generating ? '⏳ Queuing…' : promptCount > 1 ? `Generate ×${promptCount}` : 'Generate'}
</button>
</div>
</div>
)}
{/* Toast notification stack */}
<div className="fixed bottom-4 right-4 z-50 space-y-2 max-w-sm">
{notifications.map(n => (
<div
key={n.id}
className={`flex items-start gap-2 rounded border p-3 text-sm shadow-md ${
n.type === 'success'
? 'bg-green-50 dark:bg-green-900/40 border-green-200 dark:border-green-700 text-green-700 dark:text-green-300'
: n.type === 'error'
? 'bg-red-50 dark:bg-red-900/40 border-red-200 dark:border-red-700 text-red-700 dark:text-red-300'
: 'bg-blue-50 dark:bg-blue-900/40 border-blue-200 dark:border-blue-700 text-blue-700 dark:text-blue-300'
}`}
>
<span className="flex-1">{n.msg}</span>
<button
onClick={() => dismissNotif(n.id)}
className="flex-none opacity-60 hover:opacity-100 leading-none"
aria-label="Dismiss"
>
</button>
</div>
))}
</div>
</div>
)
}

View File

@@ -0,0 +1,382 @@
import React, { useState, useEffect, useRef } from 'react'
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { getHistory, createHistoryShare, revokeHistoryShare, savePresetFromHistory } from '../api/client'
import { useAuth } from '../hooks/useAuth'
interface HistoryRow {
id: number
prompt_id: string
source: string
user_label?: string
overrides: Record<string, unknown>
seed?: number
file_paths?: string[]
created_at: string
share_token?: string | null
}
/** Debounce a value by `delay` ms. */
function useDebounce<T>(value: T, delay: number): T {
const [debounced, setDebounced] = useState(value)
useEffect(() => {
const t = setTimeout(() => setDebounced(value), delay)
return () => clearTimeout(t)
}, [value, delay])
return debounced
}
export default function HistoryPage() {
const [lightbox, setLightbox] = useState<string | null>(null)
const [expandedId, setExpandedId] = useState<string | null>(null)
const [searchInput, setSearchInput] = useState('')
const { user } = useAuth()
const debouncedQ = useDebounce(searchInput.trim(), 300)
const { data, isLoading } = useQuery({
queryKey: ['history', debouncedQ],
queryFn: () => getHistory(debouncedQ || undefined),
refetchInterval: 10_000,
})
const rows: HistoryRow[] = ((data?.history ?? []) as unknown as HistoryRow[])
const isSearching = searchInput.trim() !== debouncedQ
if (isLoading && !searchInput) return <div className="text-sm text-gray-400">Loading history</div>
return (
<div className="space-y-4">
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">History</h1>
{/* Search */}
<div className="relative">
<input
type="text"
value={searchInput}
onChange={e => setSearchInput(e.target.value)}
placeholder="Search by prompt, checkpoint, seed…"
className="w-full border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-2 focus:ring-blue-500"
/>
{searchInput && (
<button
onClick={() => setSearchInput('')}
className="absolute right-2 top-1/2 -translate-y-1/2 text-gray-400 hover:text-gray-600 dark:hover:text-gray-300 text-xs px-1"
>
</button>
)}
</div>
{isSearching && <p className="text-xs text-gray-400">Searching</p>}
{isLoading ? (
<div className="text-sm text-gray-400">Loading</div>
) : rows.length === 0 ? (
<p className="text-sm text-gray-400">
{debouncedQ ? `No results for "${debouncedQ}".` : 'No generation history yet.'}
</p>
) : (
<>
{/* Desktop table — hidden on mobile */}
<div className="hidden sm:block overflow-x-auto">
<table className="w-full text-sm border-collapse">
<thead>
<tr className="text-left text-xs text-gray-400 border-b border-gray-200 dark:border-gray-700">
<th className="pb-2 pr-4">Time</th>
<th className="pb-2 pr-4">Source</th>
<th className="pb-2 pr-4">User</th>
<th className="pb-2 pr-4">Seed</th>
<th className="pb-2">Files</th>
</tr>
</thead>
<tbody>
{rows.map(row => (
<React.Fragment key={row.prompt_id}>
<tr
className="border-b border-gray-100 dark:border-gray-800 hover:bg-gray-50 dark:hover:bg-gray-800/50 cursor-pointer"
onClick={() => setExpandedId(expandedId === row.prompt_id ? null : row.prompt_id)}
>
<td className="py-2 pr-4 text-gray-500 dark:text-gray-400 whitespace-nowrap">
{new Date(row.created_at).toLocaleString()}
</td>
<td className="py-2 pr-4">
<span className={`text-xs px-1.5 py-0.5 rounded ${row.source === 'web' ? 'bg-blue-100 dark:bg-blue-900 text-blue-700 dark:text-blue-300' : 'bg-purple-100 dark:bg-purple-900 text-purple-700 dark:text-purple-300'}`}>
{row.source}
</span>
</td>
<td className="py-2 pr-4 text-gray-600 dark:text-gray-400">{row.user_label ?? '—'}</td>
<td className="py-2 pr-4 font-mono text-gray-700 dark:text-gray-300">{row.seed ?? '—'}</td>
<td className="py-2 text-gray-500 dark:text-gray-400">{(row.file_paths ?? []).length} file(s)</td>
</tr>
{expandedId === row.prompt_id && (
<tr className="bg-gray-50 dark:bg-gray-800/50">
<td colSpan={5} className="px-3 py-3">
<ExpandedRow row={row} onLightbox={setLightbox} currentUserLabel={user?.label ?? null} />
</td>
</tr>
)}
</React.Fragment>
))}
</tbody>
</table>
</div>
{/* Mobile card list — hidden on sm+ */}
<div className="sm:hidden space-y-2">
{rows.map(row => {
const isExpanded = expandedId === row.prompt_id
const dt = new Date(row.created_at)
const timeStr = dt.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })
const dateStr = dt.toLocaleDateString([], { month: 'short', day: 'numeric' })
return (
<div key={row.prompt_id} className="border border-gray-200 dark:border-gray-700 rounded bg-white dark:bg-gray-800">
<button
className="w-full text-left px-3 py-2 space-y-1"
onClick={() => setExpandedId(isExpanded ? null : row.prompt_id)}
>
<div className="flex items-center justify-between">
<span className="text-xs text-gray-500 dark:text-gray-400">{timeStr} · {dateStr}</span>
<span className={`text-xs px-1.5 py-0.5 rounded ${row.source === 'web' ? 'bg-blue-100 dark:bg-blue-900 text-blue-700 dark:text-blue-300' : 'bg-purple-100 dark:bg-purple-900 text-purple-700 dark:text-purple-300'}`}>
{row.source}
</span>
</div>
<div className="flex items-center gap-3 text-xs text-gray-600 dark:text-gray-400">
<span>User: {row.user_label ?? '—'}</span>
<span>Seed: {row.seed ?? '—'}</span>
<span>{(row.file_paths ?? []).length} file(s)</span>
<span className="ml-auto text-gray-400">{isExpanded ? '▲' : '▼'}</span>
</div>
</button>
{isExpanded && (
<div className="px-3 pb-3 border-t border-gray-100 dark:border-gray-700 pt-2">
<ExpandedRow row={row} onLightbox={setLightbox} currentUserLabel={user?.label ?? null} />
</div>
)}
</div>
)
})}
</div>
</>
)}
{/* Lightbox */}
{lightbox && (
<div
className="fixed inset-0 bg-black/80 flex items-center justify-center z-50"
onClick={() => setLightbox(null)}
>
<button
className="absolute top-3 right-3 text-white text-2xl leading-none hover:text-gray-300"
onClick={() => setLightbox(null)}
aria-label="Close"
>
</button>
<img src={lightbox} alt="preview" className="max-w-[90vw] max-h-[90vh] object-contain rounded" />
</div>
)}
</div>
)
}
function ExpandedRow({
row,
onLightbox,
currentUserLabel,
}: {
row: HistoryRow
onLightbox: (url: string) => void
currentUserLabel: string | null
}) {
const qc = useQueryClient()
const [copied, setCopied] = useState(false)
const [showSavePreset, setShowSavePreset] = useState(false)
const [presetName, setPresetName] = useState('')
const [presetDesc, setPresetDesc] = useState('')
const [presetMsg, setPresetMsg] = useState<{ ok: boolean; text: string } | null>(null)
const presetNameRef = useRef<HTMLInputElement>(null)
const { data: imagesData, isLoading } = useQuery({
queryKey: ['history', row.prompt_id, 'images'],
queryFn: async () => {
const res = await fetch(`/api/history/${row.prompt_id}/images`, { credentials: 'include' })
if (!res.ok) throw new Error(`${res.status}`)
return res.json() as Promise<{ images: Array<{ filename: string; data: string | null; mime_type: string }> }>
},
})
const shareMut = useMutation({
mutationFn: () => createHistoryShare(row.prompt_id),
onSuccess: () => qc.invalidateQueries({ queryKey: ['history'] }),
})
const revokeMut = useMutation({
mutationFn: () => revokeHistoryShare(row.prompt_id),
onSuccess: () => qc.invalidateQueries({ queryKey: ['history'] }),
})
const savePresetMut = useMutation({
mutationFn: () => savePresetFromHistory(row.prompt_id, presetName.trim(), presetDesc.trim() || undefined),
onSuccess: (res) => {
qc.invalidateQueries({ queryKey: ['presets'] })
setPresetMsg({ ok: true, text: `Saved as preset "${res.name}"` })
setPresetName('')
setPresetDesc('')
},
onError: (err) => {
setPresetMsg({ ok: false, text: err instanceof Error ? err.message : String(err) })
},
})
const shareUrl = row.share_token
? `${window.location.origin}/share/${row.share_token}`
: null
const isOwner = currentUserLabel !== null && row.user_label === currentUserLabel
const handleCopy = () => {
if (!shareUrl) return
navigator.clipboard.writeText(shareUrl).then(() => {
setCopied(true)
setTimeout(() => setCopied(false), 2000)
})
}
return (
<div className="space-y-3">
{/* Overrides */}
<details>
<summary className="text-xs text-gray-400 cursor-pointer">Overrides</summary>
<pre className="text-xs bg-gray-100 dark:bg-gray-700 rounded p-2 mt-1 overflow-auto max-h-32 text-gray-600 dark:text-gray-400">
{JSON.stringify(row.overrides, null, 2)}
</pre>
</details>
{/* Images */}
{isLoading ? (
<p className="text-xs text-gray-400">Loading images</p>
) : (imagesData?.images ?? []).length > 0 ? (
<div className="flex gap-2 flex-wrap">
{(imagesData?.images ?? []).map((img, i) => {
const isVideo = img.mime_type.startsWith('video/')
if (isVideo) {
const videoSrc = `/api/history/${row.prompt_id}/file/${encodeURIComponent(img.filename)}`
return (
<video
key={i}
src={videoSrc}
controls
className="rounded max-h-40"
/>
)
}
const src = `data:${img.mime_type};base64,${img.data}`
return (
<img
key={i}
src={src}
alt={img.filename}
className="rounded max-h-40 cursor-pointer border border-gray-200 dark:border-gray-700"
onClick={() => onLightbox(src)}
/>
)
})}
</div>
) : (
<p className="text-xs text-gray-400">Files not available (may have been moved or deleted).</p>
)}
{/* Owner-only actions */}
{isOwner && (
<div className="pt-1 border-t border-gray-100 dark:border-gray-700 space-y-2">
{/* Share section */}
{shareUrl ? (
<div className="space-y-1">
<p className="text-xs text-gray-500 dark:text-gray-400">Share link</p>
<div className="flex items-center gap-2 flex-wrap">
<code className="text-xs bg-gray-100 dark:bg-gray-700 rounded px-2 py-1 text-gray-700 dark:text-gray-300 break-all select-all flex-1 min-w-0">
{shareUrl}
</code>
<button
onClick={handleCopy}
className="shrink-0 text-xs px-2 py-1 rounded bg-blue-100 dark:bg-blue-900 text-blue-700 dark:text-blue-300 hover:bg-blue-200 dark:hover:bg-blue-800 transition-colors"
>
{copied ? 'Copied!' : 'Copy'}
</button>
<button
onClick={() => revokeMut.mutate()}
disabled={revokeMut.isPending}
className="shrink-0 text-xs px-2 py-1 rounded bg-red-100 dark:bg-red-900 text-red-700 dark:text-red-300 hover:bg-red-200 dark:hover:bg-red-800 disabled:opacity-50 transition-colors"
>
{revokeMut.isPending ? 'Revoking…' : 'Revoke'}
</button>
</div>
</div>
) : (
<button
onClick={() => shareMut.mutate()}
disabled={shareMut.isPending}
className="text-xs px-2 py-1 rounded bg-gray-100 dark:bg-gray-700 text-gray-600 dark:text-gray-400 hover:bg-gray-200 dark:hover:bg-gray-600 disabled:opacity-50 transition-colors"
>
{shareMut.isPending ? 'Creating link…' : 'Share'}
</button>
)}
{/* Save-as-preset section */}
{!showSavePreset ? (
<button
onClick={() => { setShowSavePreset(true); setPresetMsg(null); setTimeout(() => presetNameRef.current?.focus(), 50) }}
className="text-xs px-2 py-1 rounded bg-gray-100 dark:bg-gray-700 text-gray-600 dark:text-gray-400 hover:bg-gray-200 dark:hover:bg-gray-600 transition-colors"
>
Save as preset
</button>
) : (
<div className="space-y-1.5">
<p className="text-xs text-gray-500 dark:text-gray-400">Save overrides as preset</p>
<p className="text-xs text-amber-600 dark:text-amber-400">
Note: workflow template is not saved load it separately before using this preset.
</p>
<div className="flex gap-2 flex-wrap">
<input
ref={presetNameRef}
type="text"
value={presetName}
onChange={e => setPresetName(e.target.value)}
placeholder="Preset name"
className="flex-1 min-w-0 border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-xs bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
/>
<input
type="text"
value={presetDesc}
onChange={e => setPresetDesc(e.target.value)}
placeholder="Description (optional)"
className="flex-1 min-w-0 border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-xs bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
/>
</div>
<div className="flex gap-2">
<button
onClick={() => savePresetMut.mutate()}
disabled={!presetName.trim() || savePresetMut.isPending}
className="text-xs px-2 py-1 rounded bg-blue-600 hover:bg-blue-700 disabled:opacity-50 text-white transition-colors"
>
{savePresetMut.isPending ? 'Saving…' : 'Save'}
</button>
<button
onClick={() => { setShowSavePreset(false); setPresetMsg(null); setPresetName(''); setPresetDesc('') }}
className="text-xs px-2 py-1 rounded bg-gray-100 dark:bg-gray-700 text-gray-600 dark:text-gray-400 hover:bg-gray-200 dark:hover:bg-gray-600 transition-colors"
>
Cancel
</button>
</div>
{presetMsg && (
<p className={`text-xs ${presetMsg.ok ? 'text-green-600 dark:text-green-400' : 'text-red-500'}`}>
{presetMsg.text}
</p>
)}
</div>
)}
</div>
)}
</div>
)
}

View File

@@ -0,0 +1,216 @@
import React, { useRef, useState } from 'react'
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import {
listInputs,
uploadInput,
activateInput,
deleteInput,
getInputImage,
getInputThumb,
getInputMid,
getWorkflowInputs,
getState,
InputImage,
} from '../api/client'
import LazyImage from '../components/LazyImage'
export default function InputImagesPage() {
const qc = useQueryClient()
const fileRef = useRef<HTMLInputElement>(null)
const [uploading, setUploading] = useState(false)
const [uploadSlot, setUploadSlot] = useState<string | null>(null)
const [lightbox, setLightbox] = useState<string | null>(null)
const { data: images = [], isLoading } = useQuery({ queryKey: ['inputs'], queryFn: listInputs })
const { data: inputsData } = useQuery({ queryKey: ['workflow', 'inputs'], queryFn: getWorkflowInputs })
const { data: stateData } = useQuery({ queryKey: ['state'], queryFn: getState })
const imageSlots = inputsData
? [...inputsData.common, ...inputsData.advanced].filter(i => i.input_type === 'image')
: []
const deleteMut = useMutation({
mutationFn: (id: number) => deleteInput(id),
onSuccess: () => qc.invalidateQueries({ queryKey: ['inputs'] }),
})
const activateMut = useMutation({
mutationFn: ({ id, slotKey }: { id: number; slotKey: string }) => activateInput(id, slotKey),
onSuccess: () => {
qc.invalidateQueries({ queryKey: ['inputs'] })
qc.invalidateQueries({ queryKey: ['state'] })
},
})
const handleUpload = async (slotKey: string) => {
if (!fileRef.current?.files?.length) return
setUploading(true)
setUploadSlot(slotKey)
try {
await uploadInput(fileRef.current.files[0], slotKey)
qc.invalidateQueries({ queryKey: ['inputs'] })
qc.invalidateQueries({ queryKey: ['state'] })
if (fileRef.current) fileRef.current.value = ''
} finally {
setUploading(false)
setUploadSlot(null)
}
}
const activeForSlot = (slotKey: string): string => {
const st = stateData as Record<string, unknown> | undefined
return String(st?.[slotKey] ?? '')
}
if (isLoading) return <div className="text-sm text-gray-400">Loading images</div>
return (
<div className="space-y-6">
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Input Images</h1>
{/* Per-slot sections */}
{imageSlots.length > 0 ? (
imageSlots.map(slot => {
const activeFilename = activeForSlot(slot.key)
return (
<section key={slot.key} className="space-y-2">
<div className="flex items-center justify-between">
<h2 className="text-sm font-semibold text-gray-700 dark:text-gray-300">
{slot.label}
{activeFilename && (
<span className="ml-2 text-xs font-normal text-blue-500">(active: {activeFilename})</span>
)}
</h2>
<label className="cursor-pointer text-xs bg-blue-600 hover:bg-blue-700 text-white rounded px-2 py-1">
{uploading && uploadSlot === slot.key ? 'Uploading…' : 'Upload'}
<input
ref={fileRef}
type="file"
accept="image/*"
className="hidden"
onChange={() => handleUpload(slot.key)}
/>
</label>
</div>
<ImageGrid
images={images}
activeFilename={activeFilename}
slotKey={slot.key}
onActivate={(id) => activateMut.mutate({ id, slotKey: slot.key })}
onDelete={(id) => deleteMut.mutate(id)}
onLightbox={setLightbox}
/>
</section>
)
})
) : (
/* No workflow loaded — show flat list */
<section className="space-y-2">
<div className="flex items-center justify-between">
<h2 className="text-sm font-semibold text-gray-700 dark:text-gray-300">All images</h2>
<label className="cursor-pointer text-xs bg-blue-600 hover:bg-blue-700 text-white rounded px-2 py-1">
{uploading ? 'Uploading…' : 'Upload'}
<input
ref={fileRef}
type="file"
accept="image/*"
className="hidden"
onChange={() => handleUpload('input_image')}
/>
</label>
</div>
<ImageGrid
images={images}
activeFilename=""
slotKey="input_image"
onActivate={(id) => activateMut.mutate({ id, slotKey: 'input_image' })}
onDelete={(id) => deleteMut.mutate(id)}
onLightbox={setLightbox}
/>
</section>
)}
{images.length === 0 && (
<p className="text-sm text-gray-400">No images yet. Upload one to get started.</p>
)}
{/* Lightbox */}
{lightbox && (
<div
className="fixed inset-0 bg-black/80 flex items-center justify-center z-50"
onClick={() => setLightbox(null)}
>
<button
className="absolute top-3 right-3 text-white text-2xl leading-none hover:text-gray-300"
onClick={() => setLightbox(null)}
aria-label="Close"
>
</button>
<img src={lightbox} alt="preview" className="max-w-[90vw] max-h-[90vh] object-contain rounded" />
</div>
)}
</div>
)
}
function ImageGrid({
images,
activeFilename,
slotKey,
onActivate,
onDelete,
onLightbox,
}: {
images: InputImage[]
activeFilename: string
slotKey: string
onActivate: (id: number) => void
onDelete: (id: number) => void
onLightbox: (url: string) => void
}) {
if (images.length === 0) return <p className="text-xs text-gray-400">No images.</p>
return (
<div className="grid grid-cols-3 sm:grid-cols-5 lg:grid-cols-8 gap-2">
{images.map(img => {
const isActive = img.filename === activeFilename
return (
<div
key={img.id}
className={`relative group aspect-square rounded border-2 overflow-hidden cursor-pointer ${
isActive ? 'border-blue-500' : 'border-transparent'
}`}
>
<LazyImage
thumbSrc={getInputThumb(img.id)}
midSrc={getInputMid(img.id)}
fullSrc={getInputImage(img.id)}
alt={img.filename}
className="w-full h-full"
onClick={() => onLightbox(getInputImage(img.id))}
/>
{isActive && (
<div className="absolute top-0.5 left-0.5 bg-blue-500 text-white text-[9px] px-1 rounded">active</div>
)}
<div className="absolute inset-0 bg-black/60 opacity-0 group-hover:opacity-100 [@media(hover:none)]:opacity-100 transition-opacity flex flex-col items-center justify-center gap-1">
{!isActive && (
<button
onClick={() => onActivate(img.id)}
className="text-[10px] bg-blue-600 text-white rounded px-1.5 py-0.5 hover:bg-blue-700"
>
Activate
</button>
)}
<button
onClick={() => onDelete(img.id)}
className="text-[10px] bg-red-600 text-white rounded px-1.5 py-0.5 hover:bg-red-700"
>
Delete
</button>
</div>
</div>
)
})}
</div>
)
}

View File

@@ -0,0 +1,76 @@
import React, { useState } from 'react'
import { useNavigate } from 'react-router-dom'
import { authLogin, adminLogin } from '../api/client'
import { useQueryClient } from '@tanstack/react-query'
export default function LoginPage() {
const navigate = useNavigate()
const qc = useQueryClient()
const [token, setToken] = useState('')
const [error, setError] = useState('')
const [loading, setLoading] = useState(false)
const [isAdmin, setIsAdmin] = useState(false)
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault()
setError('')
setLoading(true)
try {
if (isAdmin) {
await adminLogin(token)
} else {
await authLogin(token)
}
await qc.invalidateQueries({ queryKey: ['auth', 'me'] })
navigate('/generate')
} catch (err: unknown) {
const msg = err instanceof Error ? err.message : String(err)
if (msg.includes('429')) {
setError('Too many failed attempts. Please wait 1 hour before trying again.')
} else if (msg.includes('401')) {
setError('Invalid token or password.')
} else {
setError(msg)
}
} finally {
setLoading(false)
}
}
return (
<div className="min-h-screen flex items-center justify-center bg-gray-100 dark:bg-gray-900">
<div className="bg-white dark:bg-gray-800 shadow rounded-lg p-8 w-full max-w-sm">
<h1 className="text-xl font-bold mb-6 text-gray-800 dark:text-gray-100">ComfyUI Bot</h1>
<form onSubmit={handleSubmit} className="space-y-4">
<div>
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300 mb-1">
{isAdmin ? 'Admin password' : 'Invite token'}
</label>
<input
type="password"
value={token}
onChange={e => setToken(e.target.value)}
className="w-full border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-2 focus:ring-blue-500"
placeholder={isAdmin ? 'Password' : 'Paste your invite token'}
required
/>
</div>
{error && <p className="text-red-500 text-sm">{error}</p>}
<button
type="submit"
disabled={loading}
className="w-full bg-blue-600 hover:bg-blue-700 disabled:opacity-50 text-white rounded px-4 py-2 text-sm font-medium transition-colors"
>
{loading ? 'Logging in…' : 'Log in'}
</button>
</form>
<button
onClick={() => { setIsAdmin(a => !a); setError('') }}
className="mt-4 text-xs text-gray-400 hover:text-gray-600 dark:hover:text-gray-200 w-full text-center"
>
{isAdmin ? 'Use invite token instead' : 'Admin login'}
</button>
</div>
</div>
)
}

View File

@@ -0,0 +1,196 @@
import React, { useState } from 'react'
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { listPresets, savePreset, getPreset, loadPreset, deletePreset, PresetMeta } from '../api/client'
export default function PresetsPage() {
const qc = useQueryClient()
const [newName, setNewName] = useState('')
const [newDescription, setNewDescription] = useState('')
const [savingError, setSavingError] = useState('')
const [expanded, setExpanded] = useState<string | null>(null)
const [message, setMessage] = useState<string | null>(null)
const { data, isLoading } = useQuery({ queryKey: ['presets'], queryFn: listPresets })
const { data: presetDetail } = useQuery({
queryKey: ['preset', expanded],
queryFn: () => getPreset(expanded!),
enabled: !!expanded,
})
const saveMut = useMutation({
mutationFn: ({ name, description }: { name: string; description: string }) =>
savePreset(name, description || undefined),
onSuccess: () => {
qc.invalidateQueries({ queryKey: ['presets'] })
setNewName('')
setNewDescription('')
setSavingError('')
setMessage('Preset saved.')
},
onError: (err) => setSavingError(err instanceof Error ? err.message : String(err)),
})
const loadMut = useMutation({
mutationFn: (name: string) => loadPreset(name),
onSuccess: (_, name) => {
qc.invalidateQueries({ queryKey: ['state'] })
qc.invalidateQueries({ queryKey: ['workflowInputs'] })
setMessage(`Loaded preset: ${name}`)
},
})
const deleteMut = useMutation({
mutationFn: (name: string) => deletePreset(name),
onSuccess: () => {
qc.invalidateQueries({ queryKey: ['presets'] })
setMessage('Preset deleted.')
},
})
return (
<div className="max-w-xl mx-auto space-y-6">
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Presets</h1>
{message && (
<div className="text-sm text-blue-600 dark:text-blue-400">{message}</div>
)}
{/* Save current state */}
<div className="space-y-2">
<div className="flex flex-col sm:flex-row gap-2">
<input
type="text"
value={newName}
onChange={e => setNewName(e.target.value)}
placeholder="Preset name"
className="flex-1 border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-2 focus:ring-blue-500"
/>
<button
onClick={() => { setSavingError(''); saveMut.mutate({ name: newName, description: newDescription }) }}
disabled={!newName.trim() || saveMut.isPending}
className="bg-blue-600 hover:bg-blue-700 disabled:opacity-50 text-white rounded px-3 py-2 text-sm font-medium"
>
Save current state
</button>
</div>
<input
type="text"
value={newDescription}
onChange={e => setNewDescription(e.target.value)}
placeholder="Description (optional)"
className="w-full border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-2 focus:ring-blue-500"
/>
</div>
{savingError && <p className="text-red-500 text-sm">{savingError}</p>}
{/* Preset list */}
{isLoading ? (
<div className="text-sm text-gray-400">Loading presets</div>
) : (data?.presets ?? []).length === 0 ? (
<p className="text-sm text-gray-400">No presets saved yet.</p>
) : (
<ul className="divide-y divide-gray-200 dark:divide-gray-700 border border-gray-200 dark:border-gray-700 rounded">
{(data?.presets ?? []).map((preset: PresetMeta) => (
<li key={preset.name} className="px-3 py-2 space-y-1">
<div className="flex items-center justify-between">
<button
onClick={() => setExpanded(expanded === preset.name ? null : preset.name)}
className="text-sm font-medium text-gray-700 dark:text-gray-300 hover:text-blue-600 dark:hover:text-blue-400 text-left"
>
{preset.name}
{preset.owner && (
<span className="ml-2 text-xs text-gray-400 dark:text-gray-500 font-normal">
{preset.owner}
</span>
)}
</button>
<div className="flex gap-2">
<button
onClick={() => { setMessage(null); loadMut.mutate(preset.name) }}
disabled={loadMut.isPending}
className="text-xs bg-green-600 hover:bg-green-700 disabled:opacity-50 text-white rounded px-2 py-1"
>
Load
</button>
<button
onClick={() => { setMessage(null); deleteMut.mutate(preset.name) }}
className="text-xs bg-red-600 hover:bg-red-700 text-white rounded px-2 py-1"
>
Delete
</button>
</div>
</div>
{preset.description && (
<p className="text-xs italic text-gray-400 dark:text-gray-500">{preset.description}</p>
)}
{expanded === preset.name && presetDetail && (
<PresetDetail data={presetDetail} />
)}
</li>
))}
</ul>
)}
</div>
)
}
function PresetDetail({ data }: { data: Record<string, unknown> }) {
const state = (data.state ?? {}) as Record<string, unknown>
const { prompt, negative_prompt, seed, ...otherOverrides } = state
const hasOther = Object.keys(otherOverrides).length > 0
return (
<div className="mt-2 space-y-2 text-xs">
{prompt != null && (
<div>
<span className="font-semibold text-gray-600 dark:text-gray-400">Prompt</span>
<p className="mt-0.5 bg-gray-50 dark:bg-gray-800 rounded p-2 text-gray-700 dark:text-gray-300 whitespace-pre-wrap break-words">
{String(prompt)}
</p>
</div>
)}
{negative_prompt != null && (
<div>
<span className="font-semibold text-gray-600 dark:text-gray-400">Negative prompt</span>
<p className="mt-0.5 bg-gray-50 dark:bg-gray-800 rounded p-2 text-gray-700 dark:text-gray-300 whitespace-pre-wrap break-words">
{String(negative_prompt)}
</p>
</div>
)}
{seed != null && (
<div className="flex gap-2 items-baseline">
<span className="font-semibold text-gray-600 dark:text-gray-400">Seed</span>
<span className="font-mono text-gray-700 dark:text-gray-300">
{String(seed)}
{seed === -1 && <span className="ml-1 text-gray-400">(random)</span>}
</span>
</div>
)}
{hasOther && (
<div>
<span className="font-semibold text-gray-600 dark:text-gray-400">Other overrides</span>
<table className="mt-0.5 w-full text-xs border-collapse">
<tbody>
{Object.entries(otherOverrides).map(([k, v]) => (
<tr key={k} className="border-b border-gray-100 dark:border-gray-700">
<td className="py-0.5 pr-3 font-mono text-gray-500 dark:text-gray-400 whitespace-nowrap">{k}</td>
<td className="py-0.5 text-gray-700 dark:text-gray-300 break-all">{JSON.stringify(v)}</td>
</tr>
))}
</tbody>
</table>
</div>
)}
{!!data.workflow && (
<div className="text-green-600 dark:text-green-400">Includes workflow template</div>
)}
<details>
<summary className="cursor-pointer text-gray-400 hover:text-gray-600 dark:hover:text-gray-300">Raw JSON</summary>
<pre className="mt-1 text-xs bg-gray-100 dark:bg-gray-800 rounded p-2 overflow-auto max-h-48 text-gray-600 dark:text-gray-400">
{JSON.stringify(data, null, 2)}
</pre>
</details>
</div>
)
}

View File

@@ -0,0 +1,100 @@
import React, { useEffect, useRef, useState } from 'react'
import { useQuery, useMutation } from '@tanstack/react-query'
import { getServerStatus, serverAction, tailLogs } from '../api/client'
const ACTIONS = ['start', 'stop', 'restart'] as const
export default function ServerPage() {
const [actionMsg, setActionMsg] = useState<string | null>(null)
const logRef = useRef<HTMLPreElement>(null)
const { data: srv, refetch: refetchStatus } = useQuery({
queryKey: ['server', 'status'],
queryFn: getServerStatus,
refetchInterval: 5000,
})
const { data: logsData, refetch: refetchLogs } = useQuery({
queryKey: ['logs'],
queryFn: () => tailLogs(200),
refetchInterval: 2000,
})
// Auto-scroll log to bottom
useEffect(() => {
if (logRef.current) {
logRef.current.scrollTop = logRef.current.scrollHeight
}
}, [logsData])
const actionMut = useMutation({
mutationFn: (action: string) => serverAction(action),
onSuccess: (_, action) => {
setActionMsg(`${action} sent.`)
setTimeout(() => { setActionMsg(null); refetchStatus() }, 2000)
},
onError: (err) => setActionMsg(`Error: ${err instanceof Error ? err.message : String(err)}`),
})
const stateColor = srv?.service_state === 'SERVICE_RUNNING'
? 'text-green-600 dark:text-green-400'
: srv?.service_state === 'SERVICE_STOPPED'
? 'text-red-500 dark:text-red-400'
: 'text-yellow-500'
return (
<div className="max-w-2xl mx-auto space-y-6">
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Server</h1>
{/* Status */}
<div className="bg-white dark:bg-gray-800 rounded border border-gray-200 dark:border-gray-700 p-4 text-sm space-y-1">
<p>
State: <span className={`font-medium ${stateColor}`}>{srv?.service_state ?? '—'}</span>
</p>
<p>
HTTP: <span className={srv?.http_reachable ? 'text-green-600 dark:text-green-400' : 'text-red-500'}>
{srv == null ? '—' : srv.http_reachable ? '✅ reachable' : '❌ unreachable'}
</span>
</p>
</div>
{/* Controls */}
<div className="flex gap-2">
{ACTIONS.map(a => (
<button
key={a}
onClick={() => { setActionMsg(null); actionMut.mutate(a) }}
disabled={actionMut.isPending}
className={`text-sm rounded px-3 py-2 font-medium disabled:opacity-50 transition-colors ${
a === 'stop' ? 'bg-red-600 hover:bg-red-700 text-white'
: a === 'restart' ? 'bg-yellow-500 hover:bg-yellow-600 text-white'
: 'bg-green-600 hover:bg-green-700 text-white'
}`}
>
{a.charAt(0).toUpperCase() + a.slice(1)}
</button>
))}
</div>
{actionMsg && <p className="text-sm text-blue-600 dark:text-blue-400">{actionMsg}</p>}
{/* Log tail */}
<div className="space-y-1">
<div className="flex items-center justify-between">
<p className="text-sm font-medium text-gray-700 dark:text-gray-300">Log tail (last 200 lines)</p>
<button
onClick={() => refetchLogs()}
className="text-xs text-gray-400 hover:text-gray-600 dark:hover:text-gray-200"
>
Refresh
</button>
</div>
<pre
ref={logRef}
className="bg-gray-900 text-gray-100 text-xs rounded p-3 h-72 overflow-y-auto whitespace-pre-wrap font-mono"
>
{(logsData?.lines ?? []).join('\n') || 'No log lines available.'}
</pre>
</div>
</div>
)
}

View File

@@ -0,0 +1,119 @@
import React from 'react'
import { useParams, Link } from 'react-router-dom'
import { useQuery } from '@tanstack/react-query'
import { getShareFileUrl } from '../api/client'
interface ShareData {
prompt_id: string
created_at: string
overrides: Record<string, unknown>
seed?: number
images: Array<{ filename: string; data: string | null; mime_type: string }>
}
export default function SharePage() {
const { token } = useParams<{ token: string }>()
const { data, isLoading, error } = useQuery({
queryKey: ['share', token],
queryFn: async () => {
const res = await fetch(`/api/share/${token}`, { credentials: 'include' })
if (!res.ok) {
const msg = await res.text().catch(() => res.statusText)
throw Object.assign(new Error(msg), { status: res.status })
}
return res.json() as Promise<ShareData>
},
retry: false,
})
if (isLoading) {
return (
<div className="min-h-screen flex items-center justify-center bg-gray-100 dark:bg-gray-900">
<p className="text-sm text-gray-400">Loading</p>
</div>
)
}
const status = (error as any)?.status
if (status === 401) {
return (
<div className="min-h-screen flex items-center justify-center bg-gray-100 dark:bg-gray-900">
<div className="bg-white dark:bg-gray-800 shadow rounded-lg p-8 w-full max-w-sm text-center space-y-4">
<p className="text-gray-700 dark:text-gray-300">You need to be logged in to view this shared link.</p>
<Link
to="/login"
className="inline-block bg-blue-600 hover:bg-blue-700 text-white rounded px-4 py-2 text-sm font-medium transition-colors"
>
Log in
</Link>
</div>
</div>
)
}
if (error) {
return (
<div className="min-h-screen flex items-center justify-center bg-gray-100 dark:bg-gray-900">
<div className="bg-white dark:bg-gray-800 shadow rounded-lg p-8 w-full max-w-sm text-center">
<p className="text-gray-700 dark:text-gray-300">This share link has been revoked or does not exist.</p>
</div>
</div>
)
}
return (
<div className="min-h-screen bg-gray-100 dark:bg-gray-900 py-8 px-4">
<div className="max-w-3xl mx-auto space-y-6">
<div className="bg-white dark:bg-gray-800 shadow rounded-lg p-6 space-y-4">
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Shared Generation</h1>
<div className="text-sm text-gray-500 dark:text-gray-400 space-y-1">
<p>Generated: {data && new Date(data.created_at).toLocaleString()}</p>
{data?.seed != null && (
<p>Seed: <span className="font-mono">{data.seed}</span></p>
)}
</div>
{data && Object.keys(data.overrides).length > 0 && (
<details>
<summary className="text-xs text-gray-400 cursor-pointer">Overrides</summary>
<pre className="text-xs bg-gray-100 dark:bg-gray-700 rounded p-2 mt-1 overflow-auto max-h-40 text-gray-600 dark:text-gray-400">
{JSON.stringify(data.overrides, null, 2)}
</pre>
</details>
)}
{/* Images / Videos */}
<div className="flex gap-3 flex-wrap">
{(data?.images ?? []).map((img, i) => {
if (img.mime_type.startsWith('video/')) {
return (
<video
key={i}
src={getShareFileUrl(token!, img.filename)}
controls
className="rounded max-h-80 max-w-full"
/>
)
}
return (
<img
key={i}
src={`data:${img.mime_type};base64,${img.data}`}
alt={img.filename}
className="rounded max-h-80 max-w-full object-contain border border-gray-200 dark:border-gray-700"
/>
)
})}
</div>
</div>
<p className="text-center text-xs text-gray-400">
<Link to="/login" className="hover:underline">ComfyUI Bot</Link>
</p>
</div>
</div>
)
}

View File

@@ -0,0 +1,77 @@
import React from 'react'
import { useAuth } from '../hooks/useAuth'
import { useStatus } from '../hooks/useStatus'
export default function StatusPage() {
const { user } = useAuth()
const { status, executingNode } = useStatus({ enabled: !!user })
const { bot, comfy, service, upload } = status
return (
<div className="max-w-2xl mx-auto space-y-6">
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Status</h1>
<div className="grid grid-cols-1 sm:grid-cols-2 gap-4">
{/* Bot */}
<Card title="Bot">
<Row label="Latency" value={bot ? `${bot.latency_ms} ms` : '—'} />
<Row label="Uptime" value={bot?.uptime ?? '—'} />
</Card>
{/* ComfyUI */}
<Card title="ComfyUI">
<Row label="Server" value={comfy?.server ?? '—'} />
<Row
label="Reachable"
value={comfy?.reachable == null ? '—' : comfy.reachable ? '✅ yes' : '❌ no'}
/>
<Row label="Queue running" value={String(comfy?.queue_running ?? 0)} />
<Row label="Queue pending" value={String(comfy?.queue_pending ?? 0)} />
<Row label="Workflow loaded" value={comfy?.workflow_loaded ? '✓' : '✗'} />
<Row label="Last seed" value={comfy?.last_seed != null ? String(comfy.last_seed) : '—'} />
<Row label="Total generated" value={String(comfy?.total_generated ?? 0)} />
</Card>
{/* Service */}
<Card title="Service">
<Row label="State" value={service?.state ?? '—'} />
</Card>
{/* Auto-upload */}
<Card title="Auto-upload">
<Row label="Configured" value={upload?.configured ? '✓' : '✗'} />
<Row label="Running" value={upload?.running ? '⏳ yes' : 'idle'} />
<Row label="Total ok" value={String(upload?.total_ok ?? 0)} />
<Row label="Total fail" value={String(upload?.total_fail ?? 0)} />
</Card>
</div>
{/* Executing node */}
{executingNode && (
<div className="bg-blue-50 dark:bg-blue-900/30 border border-blue-200 dark:border-blue-700 rounded p-3 text-sm text-blue-700 dark:text-blue-300">
Executing node: <strong>{executingNode}</strong>
</div>
)}
</div>
)
}
function Card({ title, children }: { title: string; children: React.ReactNode }) {
return (
<div className="bg-white dark:bg-gray-800 rounded border border-gray-200 dark:border-gray-700 p-4">
<p className="text-xs font-semibold uppercase tracking-wide text-gray-400 dark:text-gray-500 mb-2">{title}</p>
<dl className="space-y-1">{children}</dl>
</div>
)
}
function Row({ label, value }: { label: string; value: string }) {
return (
<div className="flex justify-between text-sm">
<dt className="text-gray-500 dark:text-gray-400">{label}</dt>
<dd className="text-gray-800 dark:text-gray-200 font-mono text-right">{value}</dd>
</div>
)
}

View File

@@ -0,0 +1,151 @@
import React, { useRef, useState } from 'react'
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import {
getWorkflow,
getWorkflowInputs,
listWorkflowFiles,
uploadWorkflow,
loadWorkflow,
} from '../api/client'
export default function WorkflowPage() {
const qc = useQueryClient()
const fileRef = useRef<HTMLInputElement>(null)
const [uploading, setUploading] = useState(false)
const [loadingFile, setLoadingFile] = useState<string | null>(null)
const [message, setMessage] = useState<string | null>(null)
const { data: wf } = useQuery({ queryKey: ['workflow'], queryFn: getWorkflow })
const { data: inputs } = useQuery({
queryKey: ['workflow', 'inputs'],
queryFn: getWorkflowInputs,
enabled: wf?.loaded ?? false,
})
const { data: filesData, refetch: refetchFiles } = useQuery({
queryKey: ['workflow', 'files'],
queryFn: listWorkflowFiles,
})
const loadMut = useMutation({
mutationFn: (filename: string) => loadWorkflow(filename),
onSuccess: (_, filename) => {
setMessage(`Loaded: ${filename}`)
qc.invalidateQueries({ queryKey: ['workflow'] })
qc.invalidateQueries({ queryKey: ['state'] })
setLoadingFile(null)
},
onError: (err) => {
setMessage(`Error: ${err instanceof Error ? err.message : String(err)}`)
setLoadingFile(null)
},
})
const handleUpload = async () => {
if (!fileRef.current?.files?.length) return
setUploading(true)
setMessage(null)
try {
const res = await uploadWorkflow(fileRef.current.files[0])
setMessage(`Uploaded: ${res.filename ?? 'ok'}`)
refetchFiles()
if (fileRef.current) fileRef.current.value = ''
} catch (err) {
setMessage(`Upload error: ${err instanceof Error ? err.message : String(err)}`)
} finally {
setUploading(false)
}
}
const allInputs = [...(inputs?.common ?? []), ...(inputs?.advanced ?? [])]
return (
<div className="max-w-2xl mx-auto space-y-6">
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Workflow</h1>
{/* Current workflow */}
<div className="bg-white dark:bg-gray-800 rounded border border-gray-200 dark:border-gray-700 p-4 text-sm space-y-1">
<p className="font-medium text-gray-700 dark:text-gray-300">Current workflow</p>
{wf?.loaded ? (
<>
<p className="text-gray-500 dark:text-gray-400">{wf.last_workflow_file ?? '(loaded from state)'}</p>
<p className="text-gray-500 dark:text-gray-400">{wf.node_count} node(s) detected</p>
</>
) : (
<p className="text-gray-400">No workflow loaded</p>
)}
</div>
{message && (
<div className="text-sm text-blue-600 dark:text-blue-400">{message}</div>
)}
{/* Upload */}
<div className="flex items-center gap-3">
<label className="cursor-pointer text-sm bg-blue-600 hover:bg-blue-700 text-white rounded px-3 py-2">
{uploading ? 'Uploading…' : 'Upload workflow JSON'}
<input
ref={fileRef}
type="file"
accept=".json"
className="hidden"
onChange={handleUpload}
/>
</label>
<span className="text-xs text-gray-400">Uploads to workflows/ folder</span>
</div>
{/* Available files */}
{(filesData?.files ?? []).length > 0 && (
<div className="space-y-2">
<p className="text-sm font-medium text-gray-700 dark:text-gray-300">Available workflows</p>
<ul className="divide-y divide-gray-200 dark:divide-gray-700 border border-gray-200 dark:border-gray-700 rounded">
{(filesData?.files ?? []).map(f => (
<li key={f} className="flex items-center justify-between px-3 py-2 text-sm">
<span className={`text-gray-700 dark:text-gray-300 ${wf?.last_workflow_file === f ? 'font-semibold text-blue-600 dark:text-blue-400' : ''}`}>
{f}
{wf?.last_workflow_file === f && <span className="ml-1 text-xs">(active)</span>}
</span>
<button
onClick={() => { setLoadingFile(f); setMessage(null); loadMut.mutate(f) }}
disabled={loadingFile === f}
className="text-xs bg-gray-200 dark:bg-gray-700 hover:bg-gray-300 dark:hover:bg-gray-600 rounded px-2 py-1 disabled:opacity-50"
>
{loadingFile === f ? 'Loading…' : 'Load'}
</button>
</li>
))}
</ul>
</div>
)}
{/* Discovered inputs summary */}
{inputs && allInputs.length > 0 && (
<div className="space-y-2">
<p className="text-sm font-medium text-gray-700 dark:text-gray-300">Discovered inputs ({allInputs.length})</p>
<div className="overflow-x-auto">
<table className="w-full text-xs border-collapse border border-gray-200 dark:border-gray-700 rounded">
<thead>
<tr className="bg-gray-100 dark:bg-gray-700">
<th className="text-left p-2 border-b border-gray-200 dark:border-gray-600">Key</th>
<th className="text-left p-2 border-b border-gray-200 dark:border-gray-600">Label</th>
<th className="text-left p-2 border-b border-gray-200 dark:border-gray-600">Type</th>
<th className="text-left p-2 border-b border-gray-200 dark:border-gray-600">Common</th>
</tr>
</thead>
<tbody>
{allInputs.map(inp => (
<tr key={inp.key} className="border-b border-gray-100 dark:border-gray-700 last:border-0">
<td className="p-2 font-mono text-gray-600 dark:text-gray-400">{inp.key}</td>
<td className="p-2 text-gray-700 dark:text-gray-300">{inp.label}</td>
<td className="p-2 text-gray-500 dark:text-gray-400">{inp.input_type}</td>
<td className="p-2 text-gray-500 dark:text-gray-400">{inp.is_common ? '✓' : ''}</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
)}
</div>
)
}

View File

@@ -0,0 +1,9 @@
/** @type {import('tailwindcss').Config} */
export default {
content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
darkMode: 'class',
theme: {
extend: {},
},
plugins: [],
}

24
frontend/tsconfig.json Normal file
View File

@@ -0,0 +1,24 @@
{
"compilerOptions": {
"target": "ES2020",
"useDefineForClassFields": true,
"lib": ["ES2020", "DOM", "DOM.Iterable"],
"module": "ESNext",
"skipLibCheck": true,
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"resolveJsonModule": true,
"isolatedModules": true,
"noEmit": true,
"jsx": "react-jsx",
"strict": true,
"noUnusedLocals": false,
"noUnusedParameters": false,
"noFallthroughCasesInSwitch": true,
"baseUrl": ".",
"paths": {
"@/*": ["src/*"]
}
},
"include": ["src"]
}

25
frontend/vite.config.ts Normal file
View File

@@ -0,0 +1,25 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
export default defineConfig({
plugins: [react()],
build: {
outDir: '../web-static',
emptyOutDir: true,
},
server: {
port: 5173,
proxy: {
'/api': {
target: 'http://localhost:8080',
changeOrigin: true,
secure: false,
},
'/ws': {
target: 'ws://localhost:8080',
ws: true,
changeOrigin: true,
},
},
},
})

322
generation_db.py Normal file
View File

@@ -0,0 +1,322 @@
"""
generation_db.py
================
SQLite persistence for ComfyUI generation history and output file blobs.
Two tables
----------
generation_history : one row per prompt submitted to ComfyUI
generation_files : one row per output file (image / video) as a BLOB
The module-level ``_DB_PATH`` is set by :func:`init_db`; all other
functions use that path so callers never need to pass it around.
"""
from __future__ import annotations
import json
import secrets
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
_DB_PATH: Path = Path(__file__).parent / "generation_history.db"
_SCHEMA = """
CREATE TABLE IF NOT EXISTS generation_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
prompt_id TEXT UNIQUE NOT NULL,
source TEXT NOT NULL,
user_label TEXT,
overrides TEXT,
seed INTEGER,
created_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS generation_files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
generation_id INTEGER NOT NULL REFERENCES generation_history(id),
filename TEXT NOT NULL,
file_data BLOB NOT NULL,
mime_type TEXT
);
CREATE TABLE IF NOT EXISTS generation_shares (
id INTEGER PRIMARY KEY AUTOINCREMENT,
share_token TEXT UNIQUE NOT NULL,
prompt_id TEXT NOT NULL,
owner_label TEXT NOT NULL,
created_at TEXT NOT NULL
);
"""
def _connect(db_path: Path | None = None) -> sqlite3.Connection:
path = db_path if db_path is not None else _DB_PATH
conn = sqlite3.connect(str(path), check_same_thread=False)
conn.row_factory = sqlite3.Row
return conn
def _detect_mime(data: bytes) -> str:
"""Detect MIME type from magic bytes."""
if data[:8] == b"\x89PNG\r\n\x1a\n":
return "image/png"
if data[:2] == b"\xff\xd8":
return "image/jpeg"
if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"WEBP":
return "image/webp"
if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"AVI ":
return "video/x-msvideo"
if len(data) >= 8 and data[4:8] == b"ftyp":
return "video/mp4"
if data[:4] == b"\x1aE\xdf\xa3": # EBML (WebM/MKV)
return "video/webm"
return "application/octet-stream"
def init_db(db_path: Path = _DB_PATH) -> None:
"""Create tables if they don't exist. Accepts a path for testability."""
global _DB_PATH
_DB_PATH = db_path
with _connect(db_path) as conn:
conn.executescript(_SCHEMA)
conn.commit()
def record_generation(
prompt_id: str,
source: str,
user_label: str | None,
overrides_dict: dict[str, Any] | None,
seed: int | None,
) -> int:
"""Insert a generation history row. Returns the auto-increment ``id``."""
overrides_json = json.dumps(overrides_dict) if overrides_dict is not None else None
created_at = datetime.now(timezone.utc).isoformat()
with _connect() as conn:
cur = conn.execute(
"""
INSERT INTO generation_history
(prompt_id, source, user_label, overrides, seed, created_at)
VALUES (?, ?, ?, ?, ?, ?)
""",
(prompt_id, source, user_label, overrides_json, seed, created_at),
)
conn.commit()
return cur.lastrowid # type: ignore[return-value]
def record_file(generation_id: int, filename: str, file_data: bytes) -> None:
"""Insert a file BLOB row, auto-detecting MIME type from magic bytes."""
mime_type = _detect_mime(file_data)
with _connect() as conn:
conn.execute(
"""
INSERT INTO generation_files (generation_id, filename, file_data, mime_type)
VALUES (?, ?, ?, ?)
""",
(generation_id, filename, file_data, mime_type),
)
conn.commit()
def _rows_to_history(conn: sqlite3.Connection, rows) -> list[dict]:
"""Convert raw generation_history rows (with optional share_token) to dicts."""
result: list[dict] = []
for row in rows:
d = dict(row)
if d["overrides"]:
try:
d["overrides"] = json.loads(d["overrides"])
except (json.JSONDecodeError, TypeError):
d["overrides"] = {}
else:
d["overrides"] = {}
files = conn.execute(
"SELECT filename FROM generation_files WHERE generation_id = ?",
(d["id"],),
).fetchall()
d["file_paths"] = [f["filename"] for f in files]
result.append(d)
return result
def get_history(limit: int = 50) -> list[dict]:
"""Return recent generation rows (newest first) with a ``file_paths`` list."""
with _connect() as conn:
rows = conn.execute(
"""
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
s.share_token
FROM generation_history h
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label
ORDER BY h.id DESC LIMIT ?
""",
(limit,),
).fetchall()
return _rows_to_history(conn, rows)
def get_history_for_user(user_label: str, limit: int = 50) -> list[dict]:
"""Return recent generation rows for a specific user (newest first)."""
with _connect() as conn:
rows = conn.execute(
"""
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
s.share_token
FROM generation_history h
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ?
WHERE h.user_label = ?
ORDER BY h.id DESC LIMIT ?
""",
(user_label, user_label, limit),
).fetchall()
return _rows_to_history(conn, rows)
def get_generation(prompt_id: str) -> dict | None:
"""Return the generation_history row for *prompt_id*, or None."""
with _connect() as conn:
row = conn.execute(
"SELECT id, prompt_id, user_label FROM generation_history WHERE prompt_id = ?",
(prompt_id,),
).fetchone()
return dict(row) if row else None
def get_generation_full(prompt_id: str) -> dict | None:
"""Return overrides (parsed) + seed for *prompt_id*, or None if not found."""
with _connect() as conn:
row = conn.execute(
"SELECT prompt_id, user_label, overrides, seed FROM generation_history WHERE prompt_id = ?",
(prompt_id,),
).fetchone()
if row is None:
return None
d = dict(row)
if d["overrides"]:
try:
d["overrides"] = json.loads(d["overrides"])
except (json.JSONDecodeError, TypeError):
d["overrides"] = {}
else:
d["overrides"] = {}
return d
def search_history_for_user(user_label: str, query: str, limit: int = 50) -> list[dict]:
"""Return history rows where the overrides JSON contains *query* (case-insensitive)."""
with _connect() as conn:
rows = conn.execute(
"""
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
s.share_token
FROM generation_history h
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ?
WHERE h.user_label = ? AND LOWER(h.overrides) LIKE LOWER(?)
ORDER BY h.id DESC LIMIT ?
""",
(user_label, user_label, f"%{query}%", limit),
).fetchall()
return _rows_to_history(conn, rows)
def search_history(query: str, limit: int = 50) -> list[dict]:
"""Admin version: search all users' history for *query* in overrides JSON."""
with _connect() as conn:
rows = conn.execute(
"""
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
s.share_token
FROM generation_history h
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label
WHERE LOWER(h.overrides) LIKE LOWER(?)
ORDER BY h.id DESC LIMIT ?
""",
(f"%{query}%", limit),
).fetchall()
return _rows_to_history(conn, rows)
def create_share(prompt_id: str, owner_label: str) -> str:
"""Create a share token for *prompt_id*. Idempotent — returns the same token if one exists."""
token = secrets.token_urlsafe(32)
created_at = datetime.now(timezone.utc).isoformat()
with _connect() as conn:
conn.execute(
"""
INSERT OR IGNORE INTO generation_shares (share_token, prompt_id, owner_label, created_at)
VALUES (?, ?, ?, ?)
""",
(token, prompt_id, owner_label, created_at),
)
conn.commit()
row = conn.execute(
"SELECT share_token FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
(prompt_id, owner_label),
).fetchone()
return row["share_token"]
def revoke_share(prompt_id: str, owner_label: str) -> bool:
"""Delete the share token for *prompt_id*. Returns True if a row was deleted."""
with _connect() as conn:
cur = conn.execute(
"DELETE FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
(prompt_id, owner_label),
)
conn.commit()
return cur.rowcount > 0
def get_share_by_token(token: str) -> dict | None:
"""Return generation info for a share token, or None if not found/revoked."""
with _connect() as conn:
row = conn.execute(
"""
SELECT h.prompt_id, h.overrides, h.seed, h.created_at
FROM generation_shares s
JOIN generation_history h ON h.prompt_id = s.prompt_id
WHERE s.share_token = ?
""",
(token,),
).fetchone()
if row is None:
return None
d = dict(row)
if d["overrides"]:
try:
d["overrides"] = json.loads(d["overrides"])
except (json.JSONDecodeError, TypeError):
d["overrides"] = {}
else:
d["overrides"] = {}
return d
def get_files(prompt_id: str) -> list[dict]:
"""Return all output files for *prompt_id* as ``[{filename, data, mime_type}]``."""
with _connect() as conn:
gen_row = conn.execute(
"SELECT id FROM generation_history WHERE prompt_id = ?",
(prompt_id,),
).fetchone()
if not gen_row:
return []
files = conn.execute(
"SELECT filename, file_data, mime_type FROM generation_files WHERE generation_id = ?",
(gen_row["id"],),
).fetchall()
return [
{
"filename": f["filename"],
"data": bytes(f["file_data"]),
"mime_type": f["mime_type"],
}
for f in files
]

84
image_utils.py Normal file
View File

@@ -0,0 +1,84 @@
"""
image_utils.py
==============
Shared image processing utilities.
"""
from __future__ import annotations
import io
import logging
from pathlib import Path
try:
from PIL import Image as _PILImage
_HAS_PIL = True
except ImportError:
_HAS_PIL = False
logger = logging.getLogger(__name__)
DISCORD_MAX_BYTES = 8 * 1024 * 1024 # 8 MiB Discord free-tier upload limit
def compress_to_discord_limit(data: bytes, filename: str) -> tuple[bytes, str]:
"""
Compress image bytes to fit within Discord's 8 MiB upload limit.
Tries quality reduction first, then progressive downsizing.
Converts to JPEG if the source format (e.g. PNG) cannot be quality-compressed.
Returns (data, filename) — filename may change if format conversion is needed.
No-op if data is already within the limit.
"""
if len(data) <= DISCORD_MAX_BYTES:
logger.info(f"Image size is less than {DISCORD_MAX_BYTES}")
return data, filename
if not _HAS_PIL:
logger.warning("Pillow not installed — cannot compress %s (%d bytes), uploading as-is", filename, len(data))
return data, filename
suffix = Path(filename).suffix.lower()
stem = Path(filename).stem
img = _PILImage.open(io.BytesIO(data))
# PNG/GIF/BMP don't support lossy quality — convert to JPEG
logger.info("Checking file extension for convert to jpeg")
if suffix in (".jpg", ".jpeg"):
fmt, out_name = "JPEG", filename
elif suffix == ".webp":
fmt, out_name = "WEBP", filename
else:
fmt, out_name = "JPEG", stem + ".jpg"
if img.mode in ("RGBA", "P", "LA"):
img = img.convert("RGB")
logger.info("File extension checked")
# Round 1: quality reduction only
logger.info("# Round 1: quality reduction only")
for quality in (85, 70, 55, 40, 25, 15):
logger.info(f"# Round 1: Trying quality: {quality}")
buf = io.BytesIO()
img.save(buf, format=fmt, quality=quality)
if buf.tell() <= DISCORD_MAX_BYTES:
logger.info("Compressed %s at quality=%d: %d%d bytes", filename, quality, len(data), buf.tell())
return buf.getvalue(), out_name
# Round 2: resize + low quality
logger.info("# Round 2: resize + low quality")
for scale in (0.75, 0.5, 0.35, 0.25):
w, h = img.size
resized = img.resize((int(w * scale), int(h * scale)), _PILImage.LANCZOS)
logger.info(f"# Round 2: Trying to resize: {resized.size}")
buf = io.BytesIO()
resized.save(buf, format=fmt, quality=15)
if buf.tell() <= DISCORD_MAX_BYTES:
logger.info("Compressed %s at scale=%.2f: %d%d bytes", filename, scale, len(data), buf.tell())
return buf.getvalue(), out_name
logger.warning("Could not compress %s under %d bytes — uploading best effort", filename, DISCORD_MAX_BYTES)
buf = io.BytesIO()
img.save(buf, format=fmt, quality=10)
return buf.getvalue(), out_name

232
input_image_db.py Normal file
View File

@@ -0,0 +1,232 @@
"""
input_image_db.py
=================
SQLite helpers for tracking Discord-channel-backed input images.
The database stores one row per attachment, so a single message with
multiple images produces multiple rows. The stable lookup key is the
auto-increment `id`; `(original_message_id, filename)` is a unique
composite constraint.
"""
from __future__ import annotations
import sqlite3
from pathlib import Path
DB_PATH = Path(__file__).parent / "input_images.db"
_SCHEMA = """
CREATE TABLE IF NOT EXISTS input_images (
id INTEGER PRIMARY KEY AUTOINCREMENT,
original_message_id INTEGER NOT NULL,
bot_reply_id INTEGER NOT NULL,
channel_id INTEGER NOT NULL,
filename TEXT NOT NULL,
is_active INTEGER NOT NULL DEFAULT 0,
image_data BLOB,
active_slot_key TEXT DEFAULT NULL,
UNIQUE(original_message_id, filename)
)
"""
# Live migrations applied on every startup — safe if column already exists
_MIGRATIONS = [
"ALTER TABLE input_images ADD COLUMN image_data BLOB",
"ALTER TABLE input_images ADD COLUMN active_slot_key TEXT DEFAULT NULL",
]
# Columns returned by get_image / get_all_images (excludes the potentially large BLOB)
_SAFE_COLS = "id, original_message_id, bot_reply_id, channel_id, filename, is_active, active_slot_key"
def _connect() -> sqlite3.Connection:
conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row
return conn
def init_db() -> None:
"""Create the input_images table if it does not exist, and run column migrations."""
with _connect() as conn:
conn.execute(_SCHEMA)
for stmt in _MIGRATIONS:
try:
conn.execute(stmt)
except sqlite3.OperationalError:
pass # column already exists
conn.commit()
def upsert_image(
original_message_id: int,
bot_reply_id: int,
channel_id: int,
filename: str,
image_data: bytes | None = None,
) -> int:
"""
Insert a new image record or update an existing one.
Returns the stable row ``id`` (used as the persistent view key).
When *image_data* is provided it is stored as a BLOB; on UPDATE it is
only overwritten when not None.
"""
with _connect() as conn:
existing = conn.execute(
"SELECT id FROM input_images WHERE original_message_id = ? AND filename = ?",
(original_message_id, filename),
).fetchone()
if existing:
if image_data is not None:
conn.execute(
"UPDATE input_images SET bot_reply_id = ?, channel_id = ?, image_data = ? WHERE id = ?",
(bot_reply_id, channel_id, image_data, existing["id"]),
)
else:
conn.execute(
"UPDATE input_images SET bot_reply_id = ?, channel_id = ? WHERE id = ?",
(bot_reply_id, channel_id, existing["id"]),
)
row_id = existing["id"]
else:
cur = conn.execute(
"""
INSERT INTO input_images (original_message_id, bot_reply_id, channel_id, filename, is_active, image_data)
VALUES (?, ?, ?, ?, 0, ?)
""",
(original_message_id, bot_reply_id, channel_id, filename, image_data),
)
row_id = cur.lastrowid
conn.commit()
return row_id
def get_image_data(row_id: int) -> bytes | None:
"""Return the raw image bytes for a row, or None if the row is missing or has no data."""
with _connect() as conn:
row = conn.execute(
"SELECT image_data FROM input_images WHERE id = ?",
(row_id,),
).fetchone()
if row is None:
return None
return row["image_data"]
def activate_image_for_slot(row_id: int, slot_key: str, comfy_input_path: str) -> str:
"""
Write the stored image bytes to ``{comfy_input_path}/ttb_{slot_key}{ext}``
and record the slot assignment in the DB.
Returns the basename of the written file (e.g. ``ttb_input_image.jpg``).
Raises ``ValueError`` if the row has no image_data (user must re-upload).
"""
data = get_image_data(row_id)
if data is None:
raise ValueError(
f"No image data stored for row {row_id}. Re-upload the image to backfill."
)
row = get_image(row_id)
if row is None:
raise ValueError(f"No DB record for row id {row_id}")
ext = Path(row["filename"]).suffix # e.g. ".jpg"
dest_name = f"ttb_{slot_key}{ext}"
input_path = Path(comfy_input_path)
input_path.mkdir(parents=True, exist_ok=True)
# Remove any existing file for this slot (may have a different extension)
for old in input_path.glob(f"ttb_{slot_key}.*"):
try:
old.unlink()
except Exception:
pass
(input_path / dest_name).write_bytes(data)
# Update DB: clear slot from previous holder, then assign to this row
with _connect() as conn:
conn.execute(
"UPDATE input_images SET active_slot_key = NULL WHERE active_slot_key = ?",
(slot_key,),
)
conn.execute(
"UPDATE input_images SET active_slot_key = ? WHERE id = ?",
(slot_key, row_id),
)
conn.commit()
return dest_name
def deactivate_image_slot(slot_key: str, comfy_input_path: str) -> None:
"""
Remove the ``ttb_{slot_key}.*`` file from the ComfyUI input folder and
clear the matching DB column. Safe no-op if nothing is active for that slot.
"""
input_path = Path(comfy_input_path)
for old in input_path.glob(f"ttb_{slot_key}.*"):
try:
old.unlink()
except Exception:
pass
with _connect() as conn:
conn.execute(
"UPDATE input_images SET active_slot_key = NULL WHERE active_slot_key = ?",
(slot_key,),
)
conn.commit()
def set_active(row_id: int) -> None:
"""Mark one image as active and clear the active flag on all others."""
with _connect() as conn:
conn.execute("UPDATE input_images SET is_active = 0")
conn.execute(
"UPDATE input_images SET is_active = 1 WHERE id = ?",
(row_id,),
)
conn.commit()
def get_image(row_id: int) -> dict | None:
"""Return a single image row by its auto-increment id (excluding image_data), or None."""
with _connect() as conn:
row = conn.execute(
f"SELECT {_SAFE_COLS} FROM input_images WHERE id = ?",
(row_id,),
).fetchone()
return dict(row) if row else None
def get_all_images() -> list[dict]:
"""Return all image rows as a list of dicts (excluding image_data)."""
with _connect() as conn:
rows = conn.execute(
f"SELECT {_SAFE_COLS} FROM input_images"
).fetchall()
return [dict(r) for r in rows]
def delete_image(row_id: int, comfy_input_path: str | None = None) -> None:
"""
Remove an image record from the database.
If the image is currently active for a slot and *comfy_input_path* is
provided, the corresponding ``ttb_{slot_key}.*`` file is also deleted.
"""
row = get_image(row_id)
if row and row.get("active_slot_key") and comfy_input_path:
deactivate_image_slot(row["active_slot_key"], comfy_input_path)
with _connect() as conn:
conn.execute(
"DELETE FROM input_images WHERE id = ?",
(row_id,),
)
conn.commit()

286
media_uploader.py Normal file
View File

@@ -0,0 +1,286 @@
"""
media_uploader.py
=================
Auto-uploads generated media files to the external storage server.
On success the local file is deleted. Any files that fail to upload
are left in place and will be retried automatically on the next call
(i.e. the next time a generation completes).
If no credentials are configured (MEDIA_UPLOAD_USER / MEDIA_UPLOAD_PASS
env vars not set), flush_pending() is a no-op and files are left for the
manual ttr!collect-videos command.
Upload behaviour:
- Files are categorised into image / gif / video / audio folders.
- A ``folder`` form field is sent with each upload so the server can
route the file into the correct subdirectory.
- The current datetime is appended to each filename before uploading
(e.g. ``output_20260225_143022.png``); the local filename is unchanged.
Usage::
from media_uploader import flush_pending, get_stats
await flush_pending(Path(config.comfy_output_path),
config.media_upload_user,
config.media_upload_pass)
stats = get_stats()
"""
from __future__ import annotations
import asyncio
import logging
import mimetypes
import ssl
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional
import aiohttp
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
UPLOAD_URL = "https://mediaup.revoluxiant.ddns.net/upload"
# Media categories and their recognised extensions.
CATEGORY_EXTENSIONS: dict[str, frozenset[str]] = {
"image": frozenset({
".png", ".jpg", ".jpeg", ".webp", ".bmp",
".tiff", ".tif", ".avif", ".heic", ".heif", ".svg", ".ico",
}),
"gif": frozenset({
".gif",
}),
"video": frozenset({
".mp4", ".webm", ".avi", ".mkv", ".mov",
".flv", ".ts", ".m2ts", ".m4v", ".wmv",
}),
"audio": frozenset({
".mp3", ".wav", ".ogg", ".flac", ".aac",
".m4a", ".opus", ".wma", ".aiff", ".aif",
}),
}
# Flat set of all recognised extensions (used for directory scanning).
MEDIA_EXTENSIONS: frozenset[str] = frozenset().union(*CATEGORY_EXTENSIONS.values())
# Shared SSL context — server uses a self-signed cert.
_ssl_ctx = ssl.create_default_context()
_ssl_ctx.check_hostname = False
_ssl_ctx.verify_mode = ssl.CERT_NONE
# Prevents concurrent flush runs from uploading the same file twice.
_flush_lock = asyncio.Lock()
# ---------------------------------------------------------------------------
# Stats
# ---------------------------------------------------------------------------
@dataclass
class UploadStats:
"""Cumulative upload counters for the current bot session."""
total_attempted: int = 0
total_ok: int = 0
last_attempted: int = 0
last_ok: int = 0
@property
def total_fail(self) -> int:
return self.total_attempted - self.total_ok
@property
def last_fail(self) -> int:
return self.last_attempted - self.last_ok
@property
def fail_rate_pct(self) -> float:
if self.total_attempted == 0:
return 0.0
return (self.total_fail / self.total_attempted) * 100.0
_stats = UploadStats()
def get_stats() -> UploadStats:
"""Return the module-level upload stats (live reference)."""
return _stats
def is_running() -> bool:
"""Return True if a flush is currently in progress."""
return _flush_lock.locked()
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _content_type(filepath: Path) -> str:
mime, _ = mimetypes.guess_type(filepath.name)
return mime or "application/octet-stream"
def _get_category(suffix: str) -> str:
"""Return the upload folder category for a file extension."""
s = suffix.lower()
for category, extensions in CATEGORY_EXTENSIONS.items():
if s in extensions:
logger.info(f"[_get_category] File category: {category}")
return category
logger.info(f"[_get_category] File category: other")
return "other"
def _build_upload_name(filepath: Path) -> str:
"""Return a filename with the current datetime appended before the extension.
Example: ``ComfyUI_00042.png`` → ``20260225_143022_ComfyUI_00042.png``
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return f"{timestamp}_{filepath.stem}{filepath.suffix}"
def _build_url(url, category: str) -> str:
"""Build url to upload to each specific folder"""
url += "/comfyui"
if category == "image":
url += "/image"
elif category == "video":
url += "/video"
elif category == "gif":
url += "/gif"
elif category == "audio":
url += "/audio"
else:
url
return url
async def _upload_one(session: aiohttp.ClientSession, filepath: Path) -> bool:
"""Upload a single file. Returns True on HTTP 2xx."""
try:
file_bytes = filepath.read_bytes()
except OSError:
logger.warning("Cannot read file for upload: %s", filepath)
return False
category = _get_category(filepath.suffix)
upload_name = _build_upload_name(filepath)
form = aiohttp.FormData()
form.add_field(
"file",
file_bytes,
filename=upload_name,
content_type=_content_type(filepath),
)
url = _build_url(UPLOAD_URL, category)
logger.info(f"Uploading file to url: {url}")
try:
async with session.post(
url,
data=form,
timeout=aiohttp.ClientTimeout(total=120),
) as resp:
if resp.status // 100 == 2:
return True
body = await resp.text()
logger.warning(
"Upload rejected %s: HTTP %s%s",
upload_name,
resp.status,
body[:200],
)
return False
except Exception:
logger.warning("Upload error for %s", upload_name, exc_info=True)
return False
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def flush_pending(
output_path: Path,
user: Optional[str],
password: Optional[str],
) -> int:
"""
Scan *output_path* for media files, upload each to the storage server,
and delete the local file on success.
If *user* or *password* is falsy (not configured), returns 0 immediately
and leaves all files in place for the manual ttr!collect-videos command.
Files that fail to upload are left in place and will be retried on the
next call. If a previous flush is still running the new call returns 0
immediately to avoid double-uploading.
Returns the number of files successfully uploaded and deleted.
"""
if not user or not password:
return 0
if _flush_lock.locked():
logger.debug("flush_pending already in progress, skipping")
return 0
async with _flush_lock:
try:
entries = [
e for e in output_path.iterdir()
if e.is_file() and e.suffix.lower() in MEDIA_EXTENSIONS
]
except OSError:
logger.warning("Cannot scan output directory: %s", output_path)
return 0
if not entries:
_stats.last_attempted = 0
_stats.last_ok = 0
return 0
logger.info("Auto-uploading %d pending media file(s)…", len(entries))
auth = aiohttp.BasicAuth(user, password)
connector = aiohttp.TCPConnector(ssl=_ssl_ctx)
uploaded = 0
async with aiohttp.ClientSession(connector=connector, auth=auth) as session:
for filepath in entries:
if await _upload_one(session, filepath):
try:
filepath.unlink()
logger.info("Uploaded and deleted: %s", filepath.name)
uploaded += 1
except OSError:
logger.warning(
"Uploaded but could not delete local file: %s", filepath
)
# Update cumulative stats
_stats.last_attempted = len(entries)
_stats.last_ok = uploaded
_stats.total_attempted += len(entries)
_stats.total_ok += uploaded
failed = len(entries) - uploaded
if failed:
logger.warning(
"Auto-upload: %d ok, %d failed — will retry next generation.",
uploaded,
failed,
)
elif uploaded:
logger.info("Auto-upload complete: %d file(s) uploaded and deleted.", uploaded)
return uploaded

206
preset_manager.py Normal file
View File

@@ -0,0 +1,206 @@
"""
preset_manager.py
=================
Preset management for the Discord ComfyUI bot.
A preset is a named snapshot of the current workflow template and runtime
state (prompt, negative_prompt, input_image, seed). Presets are stored as
individual JSON files inside a ``presets/`` directory so they can be
inspected, backed up, or shared manually.
Usage via bot commands:
ttr!preset save <name> — capture current workflow + state
ttr!preset load <name> — restore workflow + state from snapshot
ttr!preset list — list all saved presets
ttr!preset delete <name> — permanently remove a preset
"""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from typing import Any, Optional
logger = logging.getLogger(__name__)
# Only allow alphanumeric characters, hyphens, and underscores in preset names.
_SAFE_NAME_RE = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
class PresetManager:
"""
Manages named workflow presets on disk.
Each preset is stored as ``<presets_dir>/<name>.json`` and contains:
- ``name``: the preset name
- ``workflow``: the full ComfyUI workflow template dict (may be null)
- ``state``: the runtime state changes (prompt, negative_prompt,
input_image, seed)
Parameters
----------
presets_dir : str
Directory where preset files are stored. Created automatically if
it does not exist.
"""
def __init__(self, presets_dir: str = "presets") -> None:
self.presets_dir = Path(presets_dir)
self.presets_dir.mkdir(parents=True, exist_ok=True)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def is_valid_name(name: str) -> bool:
"""Return True if the name only contains safe characters."""
return bool(_SAFE_NAME_RE.match(name))
def _path(self, name: str) -> Path:
"""Return the file path for a preset by name (no validation)."""
return self.presets_dir / f"{name}.json"
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def save(
self,
name: str,
workflow_template: Optional[dict[str, Any]],
state: dict[str, Any],
owner: Optional[str] = None,
description: Optional[str] = None,
) -> None:
"""
Save a preset to disk.
Parameters
----------
name : str
The preset name (alphanumeric, hyphens, underscores only).
workflow_template : Optional[dict]
The current workflow template, or None if none is loaded.
state : dict
The current runtime state from WorkflowStateManager.get_changes().
owner : Optional[str]
The user label of the preset creator. Stored for access control.
description : Optional[str]
A human-readable description of what this preset does.
Raises
------
ValueError
If the name contains invalid characters.
OSError
If the file cannot be written.
"""
if not self.is_valid_name(name):
raise ValueError(
f"Invalid preset name '{name}'. "
"Use only letters, digits, hyphens, and underscores (max 64 chars)."
)
data: dict[str, Any] = {"name": name, "workflow": workflow_template, "state": state}
if owner is not None:
data["owner"] = owner
if description is not None:
data["description"] = description
path = self._path(name)
with open(path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
logger.info("Saved preset '%s' to %s", name, path)
def load(self, name: str) -> Optional[dict[str, Any]]:
"""
Load a preset from disk.
Parameters
----------
name : str
The preset name.
Returns
-------
Optional[dict]
The preset data dict, or None if the preset does not exist.
"""
if not self.is_valid_name(name):
return None
path = self._path(name)
if not path.exists():
return None
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as exc:
logger.warning("Failed to load preset '%s': %s", name, exc)
return None
def delete(self, name: str) -> bool:
"""
Delete a preset file.
Parameters
----------
name : str
The preset name.
Returns
-------
bool
True if the preset existed and was deleted, False otherwise.
"""
if not self.is_valid_name(name):
return False
path = self._path(name)
if path.exists():
path.unlink()
logger.info("Deleted preset '%s'", name)
return True
return False
def list_presets(self) -> list[str]:
"""
List all saved preset names, sorted alphabetically.
Returns
-------
list[str]
Sorted list of preset names (without the .json extension).
"""
return sorted(p.stem for p in self.presets_dir.glob("*.json"))
def list_preset_details(self) -> list[dict[str, Any]]:
"""
List all presets with their metadata, sorted alphabetically by name.
Returns
-------
list[dict]
Each entry has ``"name"``, ``"owner"`` (may be None), and
``"description"`` (may be None).
"""
result = []
for p in sorted(self.presets_dir.glob("*.json"), key=lambda x: x.stem):
owner = None
description = None
try:
with open(p, "r", encoding="utf-8") as f:
data = json.load(f)
owner = data.get("owner")
description = data.get("description")
except Exception:
pass
result.append({"name": p.stem, "owner": owner, "description": description})
return result
def exists(self, name: str) -> bool:
"""Return True if a preset with this name exists on disk."""
if not self.is_valid_name(name):
return False
return self._path(name).exists()

521
status_monitor.py Normal file
View File

@@ -0,0 +1,521 @@
"""
status_monitor.py
=================
Live status dashboard for the Discord ComfyUI bot.
Edits a single pinned message in a designated log channel every
``update_interval`` seconds. Changed values are highlighted with
bold text and directional arrows/emoji so differences are immediately
obvious.
Change-highlighting rules
-------------------------
- Unchanged good state → 🟢 value
- Unchanged bad state → 🔴 value
- Changed → bad → ⚠️ **value**
- Changed → good → ✅ **value**
- Changed (neutral) → **value**
- Queue size changed → **N** ▲ or **N** ▼
"""
from __future__ import annotations
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
import aiohttp
import discord
from commands.server import get_service_state, STATUS_EMOJI
from media_uploader import get_stats as get_upload_stats, is_running as upload_is_running, MEDIA_EXTENSIONS
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Module-level helpers (no discord.py dependency)
# ---------------------------------------------------------------------------
def _format_uptime(start: datetime) -> str:
"""Return a human-readable uptime string from a UTC start time."""
delta = datetime.now(timezone.utc) - start
total = int(delta.total_seconds())
h, rem = divmod(total, 3600)
m, s = divmod(rem, 60)
if h:
return f"{h}h {m}m {s}s"
if m:
return f"{m}m {s}s"
return f"{s}s"
def _elapsed(start: datetime) -> str:
"""Return elapsed time string since *start* (UTC)."""
delta = datetime.now(timezone.utc) - start
total = int(delta.total_seconds())
h, rem = divmod(total, 3600)
m, s = divmod(rem, 60)
if h:
return f"{h}:{m:02d}:{s:02d}"
return f"{m}:{s:02d}"
# ---------------------------------------------------------------------------
# StatusMonitor
# ---------------------------------------------------------------------------
class StatusMonitor:
"""
Periodically edits a single Discord message with live bot status.
Parameters
----------
bot :
The discord.ext.commands.Bot instance.
channel_id : int
ID of the Discord channel used for the dashboard message.
update_interval : float
Seconds between updates (default 5).
"""
HEADER_MARKER = "📊"
def __init__(self, bot, channel_id: int, update_interval: float = 10.0) -> None:
self._bot = bot
self._channel_id = channel_id
self._interval = update_interval
self._prev: dict[str, str] = {}
self._message: Optional[discord.Message] = None
self._task: Optional[asyncio.Task] = None
# ------------------------------------------------------------------
# Public lifecycle
# ------------------------------------------------------------------
async def start(self) -> None:
"""Start the update loop (idempotent)."""
if self._task is None or self._task.done():
self._task = asyncio.create_task(self._update_loop())
logger.info("StatusMonitor started for channel %s", self._channel_id)
async def stop(self) -> None:
"""Cancel the update loop, then send shutdown notice (idempotent)."""
if self._task and not self._task.done():
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
await self._send_shutdown_message()
logger.info("StatusMonitor stopped")
async def _send_shutdown_message(self) -> None:
"""Immediately edit the dashboard message to show bot offline status.
Uses a fresh aiohttp session with the bot token directly, because
discord.py closes its own HTTP session before our finally block runs
on Ctrl-C / task cancellation, making Message.edit() silently fail.
"""
if self._message is None:
# No cached message; can't create one safely during shutdown.
return
now = datetime.now(timezone.utc)
vn_time = now + timedelta(hours=7)
utc_str = now.strftime("%H:%M:%S UTC")
vn_str = vn_time.strftime("%H:%M:%S GMT+7")
text = (
f"{self.HEADER_MARKER} 🔴 **Bot Status Dashboard** — OFFLINE\n"
f"-# Shut down at: {utc_str} ({vn_str})\n"
"\n"
"Bot process has stopped."
)
try:
token = self._bot.http.token
url = (
f"https://discord.com/api/v10/channels/"
f"{self._channel_id}/messages/{self._message.id}"
)
headers = {
"Authorization": f"Bot {token}",
"Content-Type": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.patch(url, json={"content": text}, headers=headers) as resp:
if resp.status not in (200, 204):
logger.warning(
"StatusMonitor: shutdown edit returned HTTP %s", resp.status
)
except Exception as exc:
logger.warning("StatusMonitor: could not send shutdown message: %s", exc)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
async def _get_or_create_message(self) -> Optional[discord.Message]:
"""
Return the existing dashboard message or create a new one.
Searches the last 20 messages in the channel for an existing
dashboard (posted by this bot and containing the header marker).
"""
channel = self._bot.get_channel(self._channel_id)
if channel is None:
try:
channel = await self._bot.fetch_channel(self._channel_id)
except discord.NotFound:
logger.error("StatusMonitor: channel %s not found", self._channel_id)
return None
except discord.Forbidden:
logger.error("StatusMonitor: no access to channel %s", self._channel_id)
return None
# Try to find an existing dashboard message
try:
async for msg in channel.history(limit=20):
if msg.author == self._bot.user and self.HEADER_MARKER in msg.content:
return msg
except discord.HTTPException as exc:
logger.warning("StatusMonitor: history fetch failed: %s", exc)
# None found — create a fresh one
try:
msg = await channel.send(f"{self.HEADER_MARKER} **Bot Status Dashboard**\n-# Initializing…")
return msg
except discord.HTTPException as exc:
logger.error("StatusMonitor: could not create dashboard message: %s", exc)
return None
def _collect_sync(self) -> dict[str, str]:
"""
Read bot/workflow state synchronously (no async calls).
Returns a flat dict of string key → string value snapshots.
ComfyUI queue stats are filled in asynchronously in _update_loop.
"""
snap: dict[str, str] = {}
bot = self._bot
# --- Bot section ---
lat = bot.latency
latency_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0
snap["latency"] = f"{latency_ms} ms"
if hasattr(bot, "start_time"):
snap["uptime"] = _format_uptime(bot.start_time)
else:
snap["uptime"] = "unknown"
# --- ComfyUI queue section (filled async) ---
snap["comfy_pending"] = self._prev.get("comfy_pending", "?")
snap["comfy_running"] = self._prev.get("comfy_running", "?")
# --- ComfyUI section ---
comfy = getattr(bot, "comfy", None)
if comfy is not None:
snap["comfy_server"] = getattr(comfy, "server_address", "unknown")
wm = getattr(comfy, "workflow_manager", None)
workflow_loaded = wm is not None and wm.get_workflow_template() is not None
snap["workflow"] = "loaded" if workflow_loaded else "none"
sm = getattr(comfy, "state_manager", None)
if sm is not None:
changes = sm.get_changes()
p = changes.get("prompt") or ""
snap["prompt"] = (p[:50] + "" if len(p) > 50 else p) if p else ""
n = changes.get("negative_prompt") or ""
snap["neg_prompt"] = (n[:50] + "" if len(n) > 50 else n) if n else ""
img = changes.get("input_image")
snap["input_image"] = Path(img).name if img else ""
seed_pin = changes.get("seed")
snap["pinned_seed"] = str(seed_pin) if seed_pin is not None else "random"
else:
snap["prompt"] = ""
snap["neg_prompt"] = ""
snap["input_image"] = ""
snap["pinned_seed"] = ""
last_seed = getattr(comfy, "last_seed", None)
snap["last_seed"] = str(last_seed) if last_seed is not None else ""
snap["total_gen"] = str(getattr(comfy, "total_generated", 0))
else:
snap["comfy_server"] = "not configured"
snap["workflow"] = ""
snap["prompt"] = ""
snap["neg_prompt"] = ""
snap["input_image"] = ""
snap["pinned_seed"] = ""
snap["last_seed"] = ""
snap["total_gen"] = "0"
# comfy_status and service_state are filled in asynchronously
snap["comfy_status"] = self._prev.get("comfy_status", "unknown")
snap["service_state"] = self._prev.get("service_state", "unknown")
# --- Auto-upload section ---
config = getattr(bot, "config", None)
upload_user = getattr(config, "media_upload_user", None)
upload_configured = bool(upload_user)
snap["upload_configured"] = "enabled" if upload_configured else "disabled"
if upload_configured:
snap["upload_state"] = "uploading" if upload_is_running() else "idle"
# Pending: count media files sitting in the output directory
output_path_str = getattr(config, "comfy_output_path", None)
if output_path_str:
try:
pending = sum(
1 for e in Path(output_path_str).iterdir()
if e.is_file() and e.suffix.lower() in MEDIA_EXTENSIONS
)
except OSError:
pending = 0
snap["upload_pending"] = str(pending)
else:
snap["upload_pending"] = ""
us = get_upload_stats()
if us.total_attempted > 0:
snap["upload_session"] = (
f"{us.total_ok} ok, {us.total_fail} failed"
f" ({us.fail_rate_pct:.1f}%)"
)
else:
snap["upload_session"] = "no uploads yet"
if us.last_attempted > 0:
snap["upload_last"] = f"{us.last_ok} ok, {us.last_fail} failed"
else:
snap["upload_last"] = ""
else:
snap["upload_state"] = ""
snap["upload_pending"] = ""
snap["upload_session"] = ""
snap["upload_last"] = ""
return snap
async def _check_connection(self) -> str:
"""Async check whether ComfyUI is reachable. Returns a plain string."""
comfy = getattr(self._bot, "comfy", None)
if comfy is None:
return "not configured"
try:
reachable = await asyncio.wait_for(comfy.check_connection(), timeout=4.0)
return "reachable" if reachable else "unreachable"
except asyncio.TimeoutError:
return "unreachable"
except Exception:
return "unreachable"
async def _check_comfy_queue(self) -> dict[str, str]:
"""Fetch ComfyUI queue depths. Returns {comfy_pending, comfy_running}."""
comfy = getattr(self._bot, "comfy", None)
if comfy is None:
return {"comfy_pending": "?", "comfy_running": "?"}
try:
q = await asyncio.wait_for(comfy.get_comfy_queue(), timeout=4.0)
if q:
return {
"comfy_pending": str(len(q.get("queue_pending", []))),
"comfy_running": str(len(q.get("queue_running", []))),
}
except Exception:
pass
return {"comfy_pending": "?", "comfy_running": "?"}
async def _check_service_state(self) -> str:
"""Return the NSSM service state string for the configured ComfyUI service."""
config = getattr(self._bot, "config", None)
if config is None:
return "unknown"
service_name = getattr(config, "comfy_service_name", None)
if not service_name:
return "unknown"
return await get_service_state(service_name)
# ------------------------------------------------------------------
# Change-detection formatting
# ------------------------------------------------------------------
def _fmt(self, key: str, value: str, *, good: str, bad: str) -> str:
"""
Format *value* with change-detection highlighting.
Parameters
----------
key :
Snapshot key used to look up the previous value.
value :
Current value string.
good :
The value string that represents a "good" state.
bad :
The value string that represents a "bad" state.
"""
prev = self._prev.get(key)
changed = prev is not None and prev != value
is_good = value == good
is_bad = value == bad
if not changed:
if is_good:
return f"🟢 {value}"
if is_bad:
return f"🔴 {value}"
return value
# Value changed
if is_bad:
return f"⚠️ **{value}**"
if is_good:
return f"✅ **{value}**"
return f"**{value}**"
def _fmt_service_state(self, value: str) -> str:
"""Format NSSM service state with emoji and change-detection highlighting."""
prev = self._prev.get("service_state")
changed = prev is not None and prev != value
emoji = STATUS_EMOJI.get(value, "")
if not changed:
return f"{emoji} {value}"
if value == "SERVICE_RUNNING":
return f"✅ **{value}**"
if value in ("SERVICE_STOPPED", "error", "timeout"):
return f"⚠️ **{value}**"
return f"**{value}**"
def _fmt_queue_size(self, value: str, prev_key: str) -> str:
"""Format queue size with ▲/▼ arrows when changed."""
prev = self._prev.get(prev_key)
if prev is None or prev == value:
return value
try:
arrow = "" if int(value) > int(prev) else ""
except ValueError:
arrow = ""
return f"**{value}** {arrow}" if arrow else f"**{value}**"
# ------------------------------------------------------------------
# Message assembly
# ------------------------------------------------------------------
def _build_message(self, snap: dict[str, str], now: datetime) -> str:
"""Assemble the full dashboard message string."""
vn_time = now + timedelta(hours=7)
timestamp = f"{now.strftime('%H:%M:%S UTC')} ({vn_time.strftime('%H:%M:%S GMT+7')})"
pending_fmt = self._fmt_queue_size(snap["comfy_pending"], "comfy_pending")
running_fmt = self._fmt_queue_size(snap["comfy_running"], "comfy_running")
http_fmt = self._fmt("comfy_status", snap["comfy_status"], good="reachable", bad="unreachable")
svc_fmt = self._fmt_service_state(snap["service_state"])
seed_fmt = self._fmt("last_seed", snap["last_seed"], good="", bad="")
prompt_fmt = self._fmt("prompt", snap["prompt"], good="", bad="")
neg_fmt = self._fmt("neg_prompt", snap["neg_prompt"], good="", bad="")
image_fmt = self._fmt("input_image", snap["input_image"], good="", bad="")
pinned_fmt = self._fmt("pinned_seed", snap["pinned_seed"], good="", bad="")
upload_state_fmt = self._fmt(
"upload_state", snap["upload_state"], good="idle", bad=""
)
upload_pending_fmt = self._fmt(
"upload_pending", snap["upload_pending"], good="0", bad=""
)
upload_session_fmt = self._fmt("upload_session", snap["upload_session"], good="", bad="")
upload_last_fmt = self._fmt("upload_last", snap["upload_last"], good="", bad="")
lines = [
f"{self.HEADER_MARKER} **Bot Status Dashboard**",
f"-# Last updated: {timestamp}",
"",
"**Bot**",
f" Latency : {snap['latency']}",
f" Uptime : {snap['uptime']}",
"",
f"**ComfyUI** — `{snap['comfy_server']}`",
f" Service : {svc_fmt}",
f" HTTP : {http_fmt}",
f" Queue : {running_fmt} running, {pending_fmt} pending",
f" Workflow : {snap['workflow']}",
f" Prompt : || {prompt_fmt} ||",
f" Neg : || {neg_fmt} ||",
f" Image : {image_fmt}",
f" Seed : {pinned_fmt}",
"",
"**Last Generation**",
f" Seed : {seed_fmt}",
f" Total : {snap['total_gen']}",
]
if snap["upload_configured"] == "enabled":
lines += [
"",
"**Auto-Upload**",
f" State : {upload_state_fmt}",
f" Pending : {upload_pending_fmt}",
f" Session : {upload_session_fmt}",
f" Last run : {upload_last_fmt}",
]
else:
lines += [
"",
"**Auto-Upload** — disabled *(set MEDIA_UPLOAD_USER / MEDIA_UPLOAD_PASS to enable)*",
]
return "\n".join(lines)
# ------------------------------------------------------------------
# Update loop
# ------------------------------------------------------------------
async def _update_loop(self) -> None:
"""Background task: collect state, build message, edit in place."""
await self._bot.wait_until_ready()
while True:
try:
now = datetime.now(timezone.utc)
# Collect synchronous state
snap = self._collect_sync()
# Async checks run concurrently
comfy_status, service_state, queue_stats = await asyncio.gather(
self._check_connection(),
self._check_service_state(),
self._check_comfy_queue(),
)
snap["comfy_status"] = comfy_status
snap["service_state"] = service_state
snap.update(queue_stats)
# Build message text
text = self._build_message(snap, now)
# Ensure we have a message to edit
if self._message is None:
self._message = await self._get_or_create_message()
if self._message is not None:
try:
await self._message.edit(content=text)
except discord.NotFound:
# Message was deleted — recreate next cycle
self._message = None
except (discord.HTTPException, OSError) as exc:
logger.warning("StatusMonitor: edit failed: %s", exc)
# Save snapshot for next cycle's change detection
self._prev = snap
except asyncio.CancelledError:
raise
except Exception:
logger.exception("StatusMonitor: unexpected error in update loop")
await asyncio.sleep(self._interval)

162
token_store.py Normal file
View File

@@ -0,0 +1,162 @@
"""
token_store.py
==============
Invite-token CRUD for the web UI.
Tokens are stored as SHA-256 hashes so the plaintext is never at rest.
The plaintext token is only returned once (at creation time).
File format (invite_tokens.json)::
[
{
"id": "uuid",
"label": "alice",
"hash": "<sha256-hex>",
"admin": false,
"created_at": "2024-01-01T00:00:00"
},
...
]
Usage::
# Create a token (CLI)
python -c "from token_store import create_token; print(create_token('alice'))"
# Verify a token (used by auth.py)
from token_store import verify_token
record = verify_token(plaintext, token_file)
"""
from __future__ import annotations
import hashlib
import json
import logging
import secrets
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)
def _hash(token: str) -> str:
return hashlib.sha256(token.encode()).hexdigest()
def _load(token_file: str) -> list[dict]:
path = Path(token_file)
if not path.exists():
return []
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as exc:
logger.warning("Failed to load token file %s: %s", token_file, exc)
return []
def _save(token_file: str, records: list[dict]) -> None:
with open(token_file, "w", encoding="utf-8") as f:
json.dump(records, f, indent=2)
def create_token(
label: str,
token_file: str = "invite_tokens.json",
*,
admin: bool = False,
) -> str:
"""
Create a new invite token, persist its hash, and return the plaintext.
Parameters
----------
label : str
Human-readable name for this token (e.g. ``"alice"``).
token_file : str
Path to the JSON store.
admin : bool
If True, grants admin privileges.
Returns
-------
str
The plaintext token (shown once; not stored).
"""
plaintext = secrets.token_urlsafe(32)
record = {
"id": str(uuid.uuid4()),
"label": label,
"hash": _hash(plaintext),
"admin": admin,
"created_at": datetime.now(timezone.utc).isoformat(),
}
records = _load(token_file)
records.append(record)
_save(token_file, records)
logger.info("Created token for '%s' (admin=%s)", label, admin)
return plaintext
def verify_token(
plaintext: str,
token_file: str = "invite_tokens.json",
) -> Optional[dict]:
"""
Verify a plaintext token against the stored hashes.
Parameters
----------
plaintext : str
The token string provided by the user.
token_file : str
Path to the JSON store.
Returns
-------
Optional[dict]
The matching record dict (with ``label`` and ``admin`` fields),
or None if no match.
"""
h = _hash(plaintext)
for record in _load(token_file):
if record.get("hash") == h:
return record
return None
def list_tokens(token_file: str = "invite_tokens.json") -> list[dict]:
"""Return all token records (hashes included, labels safe to show)."""
return _load(token_file)
def revoke_token(token_id: str, token_file: str = "invite_tokens.json") -> bool:
"""
Delete a token by its UUID.
Returns
-------
bool
True if found and deleted, False if not found.
"""
records = _load(token_file)
new_records = [r for r in records if r.get("id") != token_id]
if len(new_records) == len(records):
return False
_save(token_file, new_records)
logger.info("Revoked token %s", token_id)
return True
if __name__ == "__main__":
import sys
label = sys.argv[1] if len(sys.argv) > 1 else "default"
is_admin = "--admin" in sys.argv
tok = create_token(label, admin=is_admin)
print(f"Token for '{label}': {tok}")

147
user_state_registry.py Normal file
View File

@@ -0,0 +1,147 @@
"""
user_state_registry.py
======================
Per-user workflow state + template registry.
Each web-UI user gets their own isolated WorkflowStateManager (persisted to
``user_settings/<user_label>.json``) and workflow template.
New users (no saved file) fall back to the global default workflow template
loaded at startup (WORKFLOW_FILE env var or last-used workflow from the
global Discord state manager).
Discord continues to use the shared global state/workflow manager — this
registry is only used by the web UI layer.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional
from workflow_state import WorkflowStateManager
logger = logging.getLogger(__name__)
_PROJECT_ROOT = Path(__file__).resolve().parent
class UserStateRegistry:
"""
Per-user isolated workflow state and template store.
Parameters
----------
settings_dir : Path
Directory where per-user state files are stored. Created automatically.
default_workflow : Optional[dict]
The global default workflow template. Used when a user has no saved
``last_workflow_file``, or the file no longer exists.
"""
def __init__(
self,
settings_dir: Path,
default_workflow: Optional[dict[str, Any]] = None,
) -> None:
self._settings_dir = settings_dir
self._settings_dir.mkdir(parents=True, exist_ok=True)
self._default_workflow: Optional[dict[str, Any]] = default_workflow
# user_label → WorkflowStateManager
self._managers: Dict[str, WorkflowStateManager] = {}
# user_label → workflow template dict (or None)
self._templates: Dict[str, Optional[dict[str, Any]]] = {}
# ------------------------------------------------------------------
# Default workflow
# ------------------------------------------------------------------
def set_default_workflow(self, template: Optional[dict[str, Any]]) -> None:
"""Update the global fallback workflow (called when bot workflow changes)."""
self._default_workflow = template
# ------------------------------------------------------------------
# User access
# ------------------------------------------------------------------
def get_state_manager(self, user_label: str) -> WorkflowStateManager:
"""Return (or lazily create) the WorkflowStateManager for a user."""
if user_label not in self._managers:
self._init_user(user_label)
return self._managers[user_label]
def get_workflow_template(self, user_label: str) -> Optional[dict[str, Any]]:
"""Return the workflow template for a user, or None if not loaded."""
if user_label not in self._managers:
self._init_user(user_label)
return self._templates.get(user_label)
def set_workflow(
self, user_label: str, template: dict[str, Any], filename: str
) -> None:
"""
Store a workflow template for a user and persist the filename.
Clears existing overrides (matches the behaviour of loading a new
workflow via the global state manager).
"""
if user_label not in self._managers:
self._init_user(user_label)
sm = self._managers[user_label]
sm.clear_overrides()
sm.set_last_workflow_file(filename)
self._templates[user_label] = template
logger.debug(
"UserStateRegistry: set workflow '%s' for user '%s'", filename, user_label
)
# ------------------------------------------------------------------
# Internals
# ------------------------------------------------------------------
def _init_user(self, user_label: str) -> None:
"""
Initialise state manager and workflow template for a new user.
1. Create WorkflowStateManager with per-user state file.
2. If the file recorded a last_workflow_file, try to load it.
3. Fall back to the global default template.
"""
state_file = str(self._settings_dir / f"{user_label}.json")
sm = WorkflowStateManager(state_file=state_file)
self._managers[user_label] = sm
# Try to restore the last workflow this user loaded
last_wf = sm.get_last_workflow_file()
template: Optional[dict[str, Any]] = None
if last_wf:
wf_path = _PROJECT_ROOT / "workflows" / last_wf
if wf_path.exists():
try:
with open(wf_path, "r", encoding="utf-8") as f:
template = json.load(f)
logger.debug(
"UserStateRegistry: restored workflow '%s' for user '%s'",
last_wf,
user_label,
)
except Exception as exc:
logger.warning(
"UserStateRegistry: could not load '%s' for user '%s': %s",
last_wf,
user_label,
exc,
)
else:
logger.debug(
"UserStateRegistry: last workflow '%s' missing for user '%s'; using default",
last_wf,
user_label,
)
if template is None:
template = self._default_workflow
self._templates[user_label] = template

309
wan2.2-fast.json Normal file
View File

@@ -0,0 +1,309 @@
{
"1": {
"inputs": {
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
"type": "wan",
"device": "default"
},
"class_type": "CLIPLoader",
"_meta": {
"title": "Load CLIP"
}
},
"2": {
"inputs": {
"vae_name": "wan_2.1_vae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"3": {
"inputs": {
"unet_name": "wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"4": {
"inputs": {
"unet_name": "wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"5": {
"inputs": {
"shift": 5.000000000000001,
"model": [
"13",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"6": {
"inputs": {
"fps": 16,
"images": [
"7",
0
]
},
"class_type": "CreateVideo",
"_meta": {
"title": "Create Video"
}
},
"7": {
"inputs": {
"samples": [
"8",
0
],
"vae": [
"2",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"8": {
"inputs": {
"add_noise": "disable",
"noise_seed": 0,
"steps": 4,
"cfg": 1,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 2,
"end_at_step": 4,
"return_with_leftover_noise": "disable",
"model": [
"5",
0
],
"positive": [
"17",
0
],
"negative": [
"17",
1
],
"latent_image": [
"11",
0
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"9": {
"inputs": {
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors",
"strength_model": 1.0000000000000002,
"model": [
"3",
0
]
},
"class_type": "LoraLoaderModelOnly",
"_meta": {
"title": "LoraLoaderModelOnly"
}
},
"10": {
"inputs": {
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors",
"strength_model": 1.0000000000000002,
"model": [
"4",
0
]
},
"class_type": "LoraLoaderModelOnly",
"_meta": {
"title": "LoraLoaderModelOnly"
}
},
"11": {
"inputs": {
"add_noise": "enable",
"noise_seed": 626287185902791,
"steps": 4,
"cfg": 1,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 0,
"end_at_step": 2,
"return_with_leftover_noise": "enable",
"model": [
"12",
0
],
"positive": [
"17",
0
],
"negative": [
"17",
1
],
"latent_image": [
"17",
2
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "KSampler (Advanced)"
}
},
"12": {
"inputs": {
"shift": 5.000000000000001,
"model": [
"14",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "ModelSamplingSD3"
}
},
"13": {
"inputs": {
"lora_name": "NSFW-22-L-e8.safetensors",
"strength_model": 1,
"model": [
"10",
0
]
},
"class_type": "LoraLoaderModelOnly",
"_meta": {
"title": "LoraLoaderModelOnly"
}
},
"14": {
"inputs": {
"lora_name": "NSFW-22-H-e8.safetensors",
"strength_model": 1,
"model": [
"9",
0
]
},
"class_type": "LoraLoaderModelOnly",
"_meta": {
"title": "LoraLoaderModelOnly"
}
},
"15": {
"inputs": {
"filename_prefix": "ComfyUI",
"format": "auto",
"codec": "auto",
"video": [
"6",
0
]
},
"class_type": "SaveVideo",
"_meta": {
"title": "Save Video"
}
},
"16": {
"inputs": {
"text": "色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走, clothes",
"clip": [
"1",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"17": {
"inputs": {
"width": 320,
"height": 640,
"length": 105,
"batch_size": 1,
"positive": [
"18",
0
],
"negative": [
"16",
0
],
"vae": [
"2",
0
],
"start_image": [
"19",
0
]
},
"class_type": "WanImageToVideo",
"_meta": {
"title": "WanImageToVideo"
}
},
"18": {
"inputs": {
"text": "A girl undress her clothes, ((completely naked)), showing her ((breasts)), masturbating",
"clip": [
"1",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Positive Prompt)"
}
},
"19": {
"inputs": {
"image": "611667883_3429056790577893_8252297829424260857_n.jpg"
},
"class_type": "LoadImage",
"_meta": {
"title": "Load Image"
}
},
"20": {
"inputs": {
"anything": [
"6",
0
]
},
"class_type": "easy cleanGpuUsed",
"_meta": {
"title": "Clean VRAM Used"
}
}
}

1
web/__init__.py Normal file
View File

@@ -0,0 +1 @@
# web package

269
web/app.py Normal file
View File

@@ -0,0 +1,269 @@
"""
web/app.py
==========
FastAPI application factory.
The app is created once and shared between the Uvicorn server (started
from bot.py via asyncio.gather) and tests.
Startup tasks:
- Background status ticker (broadcasts status_snapshot every 5s to all clients)
- Background NSSM poll (broadcasts server_state every 10s to all clients)
"""
from __future__ import annotations
import asyncio
import logging
import mimetypes
import os
from pathlib import Path
from fastapi import FastAPI
from starlette.exceptions import HTTPException as _HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request as _Request
# Windows registry can map .js → text/plain; override to the correct types
# before StaticFiles reads them.
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("application/javascript", ".mjs")
mimetypes.add_type("text/css", ".css")
mimetypes.add_type("application/wasm", ".wasm")
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
class _NoCacheHTMLMiddleware(BaseHTTPMiddleware):
"""Force browsers to revalidate index.html on every request.
Vite hashes JS/CSS filenames on every build so those assets are
naturally cache-busted. index.html itself has a stable name, so
without an explicit Cache-Control header mobile browsers apply
heuristic caching and keep serving a stale copy after a redeploy.
"""
async def dispatch(self, request: _Request, call_next):
response = await call_next(request)
ct = response.headers.get("content-type", "")
if "text/html" in ct:
response.headers["Cache-Control"] = "no-cache, must-revalidate"
return response
class _SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Add security headers to every response."""
_CSP = (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline'; "
"style-src 'self' 'unsafe-inline'; "
"img-src 'self' data: blob:; "
"connect-src 'self' wss:; "
"frame-ancestors 'none';"
)
async def dispatch(self, request: _Request, call_next):
response = await call_next(request)
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
response.headers["Content-Security-Policy"] = self._CSP
return response
class _SPAStaticFiles(StaticFiles):
"""StaticFiles with SPA fallback: serve index.html for unknown paths.
Starlette's html=True only serves index.html for directory requests.
This subclass additionally returns index.html for any path that has no
matching file, so client-side routes like /generate work on refresh.
"""
async def get_response(self, path: str, scope):
try:
return await super().get_response(path, scope)
except _HTTPException as ex:
if ex.status_code == 404:
return await super().get_response("index.html", scope)
raise
from web.ws_bus import get_bus
logger = logging.getLogger(__name__)
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
_WEB_STATIC = _PROJECT_ROOT / "web-static"
def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
app = FastAPI(
title="ComfyUI Bot Web UI",
version="1.0.0",
docs_url=None,
redoc_url=None,
openapi_url=None,
)
# CORS — only allow explicitly configured origins; empty = no cross-origin
_cors_origins = [o.strip() for o in os.getenv("CORS_ORIGINS", "").split(",") if o.strip()]
if _cors_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=_cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Security headers on every response
app.add_middleware(_SecurityHeadersMiddleware)
# Prevent browsers from caching index.html across deploys
app.add_middleware(_NoCacheHTMLMiddleware)
# Register API routers
from web.routers.auth_router import router as auth_router
from web.routers.admin_router import router as admin_router
from web.routers.status_router import router as status_router
from web.routers.state_router import router as state_router
from web.routers.generate_router import router as generate_router
from web.routers.inputs_router import router as inputs_router
from web.routers.presets_router import router as presets_router
from web.routers.server_router import router as server_router
from web.routers.history_router import router as history_router
from web.routers.share_router import router as share_router
from web.routers.workflow_router import router as workflow_router
from web.routers.ws_router import router as ws_router
app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
app.include_router(admin_router, prefix="/api/admin", tags=["admin"])
app.include_router(status_router, prefix="/api", tags=["status"])
app.include_router(state_router, prefix="/api", tags=["state"])
app.include_router(generate_router, prefix="/api", tags=["generate"])
app.include_router(inputs_router, prefix="/api/inputs", tags=["inputs"])
app.include_router(presets_router, prefix="/api/presets", tags=["presets"])
app.include_router(server_router, prefix="/api", tags=["server"])
app.include_router(history_router, prefix="/api/history", tags=["history"])
app.include_router(share_router, prefix="/api/share", tags=["share"])
app.include_router(workflow_router, prefix="/api/workflow", tags=["workflow"])
app.include_router(ws_router, tags=["ws"])
# Serve frontend static files (if built)
if _WEB_STATIC.exists() and any(_WEB_STATIC.iterdir()):
app.mount("/", _SPAStaticFiles(directory=str(_WEB_STATIC), html=True), name="static")
logger.info("Serving frontend from %s", _WEB_STATIC)
@app.on_event("startup")
async def _startup():
asyncio.create_task(_status_ticker())
asyncio.create_task(_server_state_poller())
logger.info("Web background tasks started")
return app
# ---------------------------------------------------------------------------
# Background tasks
# ---------------------------------------------------------------------------
async def _status_ticker() -> None:
"""Broadcast status_snapshot to all clients every 5 seconds."""
from web.deps import get_bot, get_comfy, get_config
bus = get_bus()
while True:
await asyncio.sleep(5)
try:
bot = get_bot()
comfy = get_comfy()
config = get_config()
snapshot: dict = {}
if bot is not None:
lat = bot.latency
lat_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0
import datetime as _dt
start = getattr(bot, "start_time", None)
uptime = ""
if start:
delta = _dt.datetime.now(_dt.timezone.utc) - start
total = int(delta.total_seconds())
h, rem = divmod(total, 3600)
m, s = divmod(rem, 60)
uptime = f"{h}h {m}m {s}s" if h else (f"{m}m {s}s" if m else f"{s}s")
snapshot["bot"] = {"latency_ms": lat_ms, "uptime": uptime}
if comfy is not None:
q = await comfy.get_comfy_queue()
pending = len(q.get("queue_pending", [])) if q else 0
running = len(q.get("queue_running", [])) if q else 0
wm = getattr(comfy, "workflow_manager", None)
wf_loaded = wm is not None and wm.get_workflow_template() is not None
snapshot["comfy"] = {
"server": comfy.server_address,
"queue_pending": pending,
"queue_running": running,
"workflow_loaded": wf_loaded,
"last_seed": comfy.last_seed,
"total_generated": comfy.total_generated,
}
if config is not None:
from media_uploader import get_stats as get_upload_stats, is_running as upload_running
try:
us = get_upload_stats()
snapshot["upload"] = {
"configured": bool(config.media_upload_user),
"running": upload_running(),
"total_ok": us.total_ok,
"total_fail": us.total_fail,
}
except Exception:
pass
from web.deps import get_user_registry
registry = get_user_registry()
connected = bus.connected_users
if connected and registry:
for ul in connected:
user_overrides = registry.get_state_manager(ul).get_overrides()
await bus.broadcast_to_user(ul, "status_snapshot", {**snapshot, "overrides": user_overrides})
else:
await bus.broadcast("status_snapshot", snapshot)
except Exception as exc:
logger.debug("Status ticker error: %s", exc)
async def _server_state_poller() -> None:
"""Poll NSSM service state and broadcast server_state every 10 seconds."""
from web.deps import get_config
bus = get_bus()
while True:
await asyncio.sleep(10)
try:
config = get_config()
if config is None:
continue
from commands.server import get_service_state
from web.deps import get_comfy
async def _false():
return False
comfy = get_comfy()
service_state, http_reachable = await asyncio.gather(
get_service_state(config.comfy_service_name),
comfy.check_connection() if comfy else _false(),
)
await bus.broadcast("server_state", {
"state": service_state,
"http_reachable": http_reachable,
})
except Exception as exc:
logger.debug("Server state poller error: %s", exc)

119
web/auth.py Normal file
View File

@@ -0,0 +1,119 @@
"""
web/auth.py
===========
JWT authentication for the web UI.
Flow:
- POST /api/auth/login {token} → verify invite token → issue JWT in httpOnly cookie
- All /api/* require valid JWT via require_auth dependency
- POST /api/admin/login {password} → issue admin JWT (admin: true claim)
- WS /ws?token=<jwt> → authenticate via query param
JWT claims: {"sub": "<label>", "admin": bool, "exp": ...}
"""
from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import Cookie, Depends, HTTPException, status
from fastapi.security import HTTPBearer
try:
from jose import JWTError, jwt
except ImportError:
jwt = None # type: ignore
JWTError = Exception # type: ignore
logger = logging.getLogger(__name__)
ALGORITHM = "HS256"
_COOKIE_NAME = "ttb_session"
def _get_secret() -> str:
from web.deps import get_config
cfg = get_config()
if cfg and cfg.web_secret_key:
return cfg.web_secret_key
raise RuntimeError(
"WEB_SECRET_KEY must be set in the environment — "
"refusing to run with an insecure default."
)
def create_jwt(label: str, *, admin: bool = False, expire_hours: int = 8) -> str:
"""Create a signed JWT for the given user label."""
if jwt is None:
raise RuntimeError("python-jose is not installed (pip install python-jose[cryptography])")
expire = datetime.now(timezone.utc) + timedelta(hours=expire_hours)
payload = {"sub": label, "admin": admin, "exp": expire}
return jwt.encode(payload, _get_secret(), algorithm=ALGORITHM)
def decode_jwt(token: str) -> Optional[dict]:
"""Decode and verify a JWT. Returns the payload or None on failure."""
if jwt is None:
return None
try:
return jwt.decode(token, _get_secret(), algorithms=[ALGORITHM])
except JWTError as exc:
logger.debug("JWT decode failed: %s", exc)
return None
def verify_ws_token(token: str) -> Optional[dict]:
"""Verify a JWT passed as a WebSocket query parameter."""
return decode_jwt(token)
# ---------------------------------------------------------------------------
# FastAPI dependencies
# ---------------------------------------------------------------------------
def require_auth(ttb_session: Optional[str] = Cookie(default=None)) -> dict:
"""
FastAPI dependency that requires a valid JWT cookie.
Returns
-------
dict
The decoded JWT payload (``sub``, ``admin`` fields).
Raises
------
HTTPException 401
If the cookie is absent or the token is invalid/expired.
"""
if not ttb_session:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
)
payload = decode_jwt(ttb_session)
if payload is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
)
return payload
def require_admin(user: dict = Depends(require_auth)) -> dict:
"""
FastAPI dependency that requires an admin JWT.
Raises
------
HTTPException 403
If the token is valid but not admin.
"""
if not user.get("admin"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin access required",
)
return user

72
web/deps.py Normal file
View File

@@ -0,0 +1,72 @@
"""
web/deps.py
===========
Shared bot reference for FastAPI dependency injection.
``set_bot()`` is called once from ``bot.py`` before starting Uvicorn.
FastAPI route handlers use ``get_bot()``, ``get_comfy()``, etc. as
Depends() callables.
"""
from __future__ import annotations
from typing import Optional
_bot = None
def set_bot(bot) -> None:
"""Store the discord.py bot instance for DI access."""
global _bot
_bot = bot
def get_bot():
"""FastAPI dependency: return the bot instance."""
return _bot
def get_comfy():
"""FastAPI dependency: return the ComfyClient."""
if _bot is None:
return None
return getattr(_bot, "comfy", None)
def get_config():
"""FastAPI dependency: return the BotConfig."""
if _bot is None:
return None
return getattr(_bot, "config", None)
def get_state_manager():
"""FastAPI dependency: return the WorkflowStateManager."""
comfy = get_comfy()
if comfy is None:
return None
return getattr(comfy, "state_manager", None)
def get_workflow_manager():
"""FastAPI dependency: return the WorkflowManager."""
comfy = get_comfy()
if comfy is None:
return None
return getattr(comfy, "workflow_manager", None)
def get_inspector():
"""FastAPI dependency: return the WorkflowInspector."""
comfy = get_comfy()
if comfy is None:
return None
return getattr(comfy, "inspector", None)
def get_user_registry():
"""FastAPI dependency: return the UserStateRegistry."""
if _bot is None:
return None
return getattr(_bot, "user_registry", None)

146
web/login_guard.py Normal file
View File

@@ -0,0 +1,146 @@
"""
web/login_guard.py
==================
IP-based brute-force protection for login endpoints.
Tracks failed login attempts per IP in a rolling time window and issues a
temporary ban when the threshold is exceeded. Uses only stdlib — no new
pip packages required.
Usage
-----
from web.login_guard import get_guard, get_real_ip
@router.post("/login")
async def login(request: Request, body: LoginRequest, response: Response):
ip = get_real_ip(request)
get_guard().check(ip) # raises 429 if locked out
...
if failure:
get_guard().record_failure(ip)
raise HTTPException(401, ...)
get_guard().record_success(ip)
...
"""
from __future__ import annotations
import logging
import time
from collections import defaultdict
from typing import Dict, List
from fastapi import HTTPException, Request, status
logger = logging.getLogger(__name__)
def get_real_ip(request: Request) -> str:
"""Return the real client IP, honouring Cloudflare and common proxy headers.
Priority:
1. ``CF-Connecting-IP`` (set by Cloudflare)
2. ``X-Real-IP`` (set by nginx/traefik)
3. ``request.client.host`` (direct connection fallback)
"""
cf_ip = request.headers.get("CF-Connecting-IP", "").strip()
if cf_ip:
return cf_ip
real_ip = request.headers.get("X-Real-IP", "").strip()
if real_ip:
return real_ip
return request.client.host if request.client else "unknown"
class BruteForceGuard:
"""Rolling-window failure counter with automatic IP bans.
All state is in-process memory. A restart clears all bans and counters,
which is acceptable — a brief restart already provides a natural backoff
for a legitimate attacker.
"""
WINDOW_SECS = 600 # rolling window: 10 minutes
MAX_FAILURES = 10 # max failures before ban
BAN_SECS = 3600 # ban duration: 1 hour
def __init__(self) -> None:
# ip → list of failure timestamps (epoch floats)
self._failures: Dict[str, List[float]] = defaultdict(list)
# ip → ban expiry timestamp
self._ban_until: Dict[str, float] = {}
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def check(self, ip: str) -> None:
"""Raise HTTP 429 if the IP is banned or has exceeded the failure threshold.
Call this *before* doing any credential work so the lockout is
evaluated even when the request body is malformed.
"""
now = time.time()
# Active ban?
ban_expiry = self._ban_until.get(ip, 0)
if ban_expiry > now:
logger.warning("login_guard: blocked request from banned ip=%s (ban expires in %.0fs)", ip, ban_expiry - now)
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Too many failed attempts. Try again later.",
)
# Failure count within rolling window
cutoff = now - self.WINDOW_SECS
recent = [t for t in self._failures[ip] if t > cutoff]
self._failures[ip] = recent # prune stale entries while we're here
if len(recent) >= self.MAX_FAILURES:
# Threshold just reached — apply ban now
self._ban_until[ip] = now + self.BAN_SECS
logger.warning(
"login_guard: threshold reached, banning ip=%s for %ds",
ip, self.BAN_SECS,
)
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Too many failed attempts. Try again later.",
)
def record_failure(self, ip: str) -> None:
"""Record a failed login attempt for the given IP."""
now = time.time()
cutoff = now - self.WINDOW_SECS
recent = [t for t in self._failures[ip] if t > cutoff]
recent.append(now)
self._failures[ip] = recent
count = len(recent)
logger.warning("login_guard: failure #%d from ip=%s", count, ip)
if count >= self.MAX_FAILURES:
self._ban_until[ip] = now + self.BAN_SECS
logger.warning(
"login_guard: threshold reached, banning ip=%s for %ds",
ip, self.BAN_SECS,
)
def record_success(self, ip: str) -> None:
"""Clear failure history and any active ban for the given IP."""
self._failures.pop(ip, None)
self._ban_until.pop(ip, None)
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
_guard: BruteForceGuard | None = None
def get_guard() -> BruteForceGuard:
"""Return the shared BruteForceGuard singleton (created on first call)."""
global _guard
if _guard is None:
_guard = BruteForceGuard()
return _guard

1
web/routers/__init__.py Normal file
View File

@@ -0,0 +1 @@
# web.routers package

View File

@@ -0,0 +1,88 @@
"""POST /api/admin/login; GET/POST/DELETE /api/admin/tokens"""
from __future__ import annotations
import hmac
import logging
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from pydantic import BaseModel
from web.auth import create_jwt, require_admin
from web.deps import get_config
from web.login_guard import get_guard, get_real_ip
router = APIRouter()
_COOKIE = "ttb_session"
audit = logging.getLogger("audit")
class AdminLoginRequest(BaseModel):
password: str
class CreateTokenRequest(BaseModel):
label: str
admin: bool = False
@router.post("/login")
async def admin_login(body: AdminLoginRequest, request: Request, response: Response):
"""Admin password login → admin JWT cookie."""
config = get_config()
expected_pw = config.admin_password if config else None
ip = get_real_ip(request)
get_guard().check(ip)
# Constant-time comparison to prevent timing attacks
if not expected_pw or not hmac.compare_digest(body.password, expected_pw):
get_guard().record_failure(ip)
audit.info("admin.login ip=%s success=False", ip)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Wrong password")
get_guard().record_success(ip)
audit.info("admin.login ip=%s success=True", ip)
expire_hours = config.web_jwt_expire_hours if config else 8
jwt_token = create_jwt("admin", admin=True, expire_hours=expire_hours)
response.set_cookie(
_COOKIE, jwt_token,
httponly=True, secure=True, samesite="strict",
max_age=expire_hours * 3600,
)
return {"label": "admin", "admin": True}
@router.get("/tokens")
async def list_tokens(_: dict = Depends(require_admin)):
"""List all invite tokens (hashes shown, labels safe)."""
from token_store import list_tokens as _list
config = get_config()
token_file = config.web_token_file if config else "invite_tokens.json"
records = _list(token_file)
# Don't return hashes to the UI
return [{"id": r["id"], "label": r["label"], "admin": r.get("admin", False),
"created_at": r.get("created_at")} for r in records]
@router.post("/tokens")
async def create_token(body: CreateTokenRequest, _: dict = Depends(require_admin)):
"""Create a new invite token. Returns the plaintext token (shown once)."""
from token_store import create_token as _create
config = get_config()
token_file = config.web_token_file if config else "invite_tokens.json"
plaintext = _create(body.label, token_file, admin=body.admin)
return {"token": plaintext, "label": body.label, "admin": body.admin}
@router.delete("/tokens/{token_id}")
async def revoke_token(token_id: str, _: dict = Depends(require_admin)):
"""Revoke an invite token by ID."""
from token_store import revoke_token as _revoke
config = get_config()
token_file = config.web_token_file if config else "invite_tokens.json"
ok = _revoke(token_id, token_file)
if not ok:
raise HTTPException(status_code=404, detail="Token not found")
return {"ok": True}

View File

@@ -0,0 +1,64 @@
"""POST /api/auth/login|logout; GET /api/auth/me"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from pydantic import BaseModel
from web.auth import create_jwt, require_auth
from web.deps import get_config
from web.login_guard import get_guard, get_real_ip
router = APIRouter()
_COOKIE = "ttb_session"
audit = logging.getLogger("audit")
class LoginRequest(BaseModel):
token: str
@router.post("/login")
async def login(body: LoginRequest, request: Request, response: Response):
"""Exchange an invite token for a JWT session cookie."""
from token_store import verify_token
config = get_config()
token_file = config.web_token_file if config else "invite_tokens.json"
expire_hours = config.web_jwt_expire_hours if config else 8
ip = get_real_ip(request)
get_guard().check(ip)
record = verify_token(body.token, token_file)
if record is None:
get_guard().record_failure(ip)
audit.info("auth.login ip=%s success=False", ip)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
label: str = record["label"]
admin: bool = record.get("admin", False)
get_guard().record_success(ip)
audit.info("auth.login ip=%s success=True label=%s", ip, label)
jwt_token = create_jwt(label, admin=admin, expire_hours=expire_hours)
response.set_cookie(
_COOKIE, jwt_token,
httponly=True, secure=True, samesite="strict",
max_age=expire_hours * 3600,
)
return {"label": label, "admin": admin}
@router.post("/logout")
async def logout(response: Response):
"""Clear the session cookie."""
response.delete_cookie(_COOKIE)
return {"ok": True}
@router.get("/me")
async def me(user: dict = Depends(require_auth)):
"""Return current user info."""
return {"label": user["sub"], "admin": user.get("admin", False)}

View File

@@ -0,0 +1,255 @@
"""POST /api/generate and /api/workflow-gen"""
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Any, Dict, Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from web.auth import require_auth
from web.deps import get_comfy, get_config, get_user_registry
from web.ws_bus import get_bus
router = APIRouter()
logger = logging.getLogger(__name__)
class GenerateRequest(BaseModel):
prompt: str
negative_prompt: Optional[str] = None
overrides: Optional[Dict[str, Any]] = None # extra per-request overrides
class WorkflowGenRequest(BaseModel):
count: int = 1
overrides: Optional[Dict[str, Any]] = None # per-request overrides (merged with state)
@router.post("/generate")
async def generate(body: GenerateRequest, user: dict = Depends(require_auth)):
"""Submit a prompt-based generation to ComfyUI."""
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "ComfyUI client not available")
user_label: str = user["sub"]
bus = get_bus()
registry = get_user_registry()
# Temporary seed override from request
if body.overrides and "seed" in body.overrides:
seed_override = body.overrides["seed"]
elif registry:
seed_override = registry.get_state_manager(user_label).get_seed()
else:
seed_override = comfy.state_manager.get_seed()
overrides_for_gen = {"prompt": body.prompt}
if body.negative_prompt:
overrides_for_gen["negative_prompt"] = body.negative_prompt
if seed_override is not None:
overrides_for_gen["seed"] = seed_override
# Also apply any extra per-request overrides
if body.overrides:
overrides_for_gen.update(body.overrides)
# Get queue position estimate
depth = await comfy.get_queue_depth()
# Start generation as background task so we can return the prompt_id immediately
prompt_id_holder: list = []
async def _run():
# Use the user's own workflow template
if registry:
template = registry.get_workflow_template(user_label)
else:
template = comfy.workflow_manager.get_workflow_template()
if not template:
await bus.broadcast_to_user(user_label, "generation_error", {
"prompt_id": None, "error": "No workflow template loaded"
})
return
import uuid
pid = str(uuid.uuid4())
prompt_id_holder.append(pid)
def on_progress(node, pid_):
asyncio.create_task(bus.broadcast("node_executing", {
"node": node, "prompt_id": pid_
}))
workflow, applied = comfy.inspector.inject_overrides(template, overrides_for_gen)
seed_used = applied.get("seed")
comfy.last_seed = seed_used
try:
images, videos = await comfy._general_generate(workflow, pid, on_progress)
except Exception as exc:
logger.exception("Generation error for prompt %s", pid)
await bus.broadcast_to_user(user_label, "generation_error", {
"prompt_id": pid, "error": str(exc)
})
return
comfy.last_prompt_id = pid
comfy.total_generated += 1
# Persist to DB before flush_pending deletes local files
config = get_config()
try:
from generation_db import record_generation, record_file
gen_id = record_generation(pid, "web", user_label, overrides_for_gen, seed_used)
for i, img_data in enumerate(images):
record_file(gen_id, f"image_{i:04d}.png", img_data)
if config and videos:
for vid in videos:
vsub = vid.get("video_subfolder", "")
vname = vid.get("video_name", "")
vpath = (
Path(config.comfy_output_path) / vsub / vname
if vsub
else Path(config.comfy_output_path) / vname
)
try:
record_file(gen_id, vname, vpath.read_bytes())
except OSError:
pass
except Exception as exc:
logger.warning("Failed to record generation to DB: %s", exc)
# Flush auto-upload
if config:
from media_uploader import flush_pending
asyncio.create_task(flush_pending(
Path(config.comfy_output_path),
config.media_upload_user,
config.media_upload_pass,
))
await bus.broadcast("queue_update", {
"prompt_id": pid,
"status": "complete",
})
await bus.broadcast_to_user(user_label, "generation_complete", {
"prompt_id": pid,
"seed": seed_used,
"image_count": len(images),
"video_count": len(videos),
})
asyncio.create_task(_run())
return {
"queued": True,
"queue_position": depth + 1,
"message": "Generation submitted to ComfyUI",
}
@router.post("/workflow-gen")
async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_auth)):
"""Submit workflow-based generation(s) to ComfyUI."""
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "ComfyUI client not available")
user_label: str = user["sub"]
bus = get_bus()
registry = get_user_registry()
count = max(1, min(body.count, 20)) # cap at 20
async def _run_one():
# Use the user's own state and template
if registry:
user_sm = registry.get_state_manager(user_label)
user_template = registry.get_workflow_template(user_label)
else:
user_sm = comfy.state_manager
user_template = comfy.workflow_manager.get_workflow_template()
if not user_template:
await bus.broadcast_to_user(user_label, "generation_error", {
"prompt_id": None, "error": "No workflow template loaded"
})
return
overrides = user_sm.get_overrides()
if body.overrides:
overrides = {**overrides, **body.overrides}
import uuid
pid = str(uuid.uuid4())
def on_progress(node, pid_):
asyncio.create_task(bus.broadcast("node_executing", {
"node": node, "prompt_id": pid_
}))
workflow, applied = comfy.inspector.inject_overrides(user_template, overrides)
seed_used = applied.get("seed")
comfy.last_seed = seed_used
try:
images, videos = await comfy._general_generate(workflow, pid, on_progress)
except Exception as exc:
logger.exception("Workflow gen error")
await bus.broadcast_to_user(user_label, "generation_error", {
"prompt_id": None, "error": str(exc)
})
return
comfy.last_prompt_id = pid
comfy.total_generated += 1
config = get_config()
try:
from generation_db import record_generation, record_file
gen_id = record_generation(pid, "web", user_label, overrides, seed_used)
for i, img_data in enumerate(images):
record_file(gen_id, f"image_{i:04d}.png", img_data)
if config and videos:
for vid in videos:
vsub = vid.get("video_subfolder", "")
vname = vid.get("video_name", "")
vpath = (
Path(config.comfy_output_path) / vsub / vname
if vsub
else Path(config.comfy_output_path) / vname
)
try:
record_file(gen_id, vname, vpath.read_bytes())
except OSError:
pass
except Exception as exc:
logger.warning("Failed to record generation to DB: %s", exc)
if config:
from media_uploader import flush_pending
asyncio.create_task(flush_pending(
Path(config.comfy_output_path),
config.media_upload_user,
config.media_upload_pass,
))
await bus.broadcast("queue_update", {"prompt_id": pid, "status": "complete"})
await bus.broadcast_to_user(user_label, "generation_complete", {
"prompt_id": pid,
"seed": seed_used,
"image_count": len(images),
"video_count": len(videos),
})
depth = await comfy.get_queue_depth()
for _ in range(count):
asyncio.create_task(_run_one())
return {
"queued": True,
"count": count,
"queue_position": depth + 1,
}

View File

@@ -0,0 +1,143 @@
"""GET /api/history; GET /api/history/{prompt_id}/images; GET /api/history/{prompt_id}/file/{filename};
POST /api/history/{prompt_id}/share; DELETE /api/history/{prompt_id}/share"""
from __future__ import annotations
import base64
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import Response
from web.auth import require_auth
router = APIRouter()
def _assert_owns(prompt_id: str, user: dict) -> None:
"""Raise 404 if the generation doesn't exist or doesn't belong to the user.
Returning the same 404 for both cases prevents leaking whether a
prompt_id exists to users who don't own it. Admins bypass this check.
"""
if user.get("admin"):
return
from generation_db import get_generation
gen = get_generation(prompt_id)
if gen is None or gen["user_label"] != user["sub"]:
raise HTTPException(404, "Not found")
@router.get("")
async def get_history(
user: dict = Depends(require_auth),
q: Optional[str] = Query(None, description="Keyword to search in overrides JSON"),
):
"""Return generation history. Admins see all; regular users see only their own.
Pass ?q=keyword to filter by prompt text or any override field."""
from generation_db import (
get_history as db_get_history,
get_history_for_user,
search_history,
search_history_for_user,
)
if q and q.strip():
if user.get("admin"):
return {"history": search_history(q.strip(), limit=50)}
return {"history": search_history_for_user(user["sub"], q.strip(), limit=50)}
if user.get("admin"):
return {"history": db_get_history(limit=50)}
return {"history": get_history_for_user(user["sub"], limit=50)}
@router.get("/{prompt_id}/images")
async def get_history_images(prompt_id: str, user: dict = Depends(require_auth)):
"""
Fetch output files for a past generation.
Returns base64-encoded blobs from the local SQLite DB — works even after
``flush_pending`` has deleted the files from disk.
"""
_assert_owns(prompt_id, user)
from generation_db import get_files
files = get_files(prompt_id)
if not files:
raise HTTPException(404, f"No files found for prompt_id {prompt_id!r}")
return {
"prompt_id": prompt_id,
"images": [
{
"filename": f["filename"],
"data": base64.b64encode(f["data"]).decode() if not f["mime_type"].startswith("video/") else None,
"mime_type": f["mime_type"],
}
for f in files
],
}
@router.get("/{prompt_id}/file/{filename}")
async def get_history_file(
prompt_id: str,
filename: str,
request: Request,
user: dict = Depends(require_auth),
):
"""Stream a single output file, with HTTP range request support for video seeking."""
_assert_owns(prompt_id, user)
from generation_db import get_files
files = get_files(prompt_id)
matched = next((f for f in files if f["filename"] == filename), None)
if matched is None:
raise HTTPException(404, f"File {filename!r} not found for prompt_id {prompt_id!r}")
data: bytes = matched["data"]
mime: str = matched["mime_type"]
total = len(data)
range_header = request.headers.get("range")
if range_header:
range_val = range_header.replace("bytes=", "")
start_str, _, end_str = range_val.partition("-")
start = int(start_str) if start_str else 0
end = int(end_str) if end_str else total - 1
end = min(end, total - 1)
chunk = data[start : end + 1]
return Response(
content=chunk,
status_code=206,
media_type=mime,
headers={
"Content-Range": f"bytes {start}-{end}/{total}",
"Accept-Ranges": "bytes",
"Content-Length": str(len(chunk)),
},
)
return Response(
content=data,
media_type=mime,
headers={
"Accept-Ranges": "bytes",
"Content-Length": str(total),
},
)
@router.post("/{prompt_id}/share")
async def create_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
"""Create a share token for a generation. Only the owner may share."""
# Use the same 404-for-everything helper to avoid leaking prompt_id existence
_assert_owns(prompt_id, user)
from generation_db import create_share as db_create_share
token = db_create_share(prompt_id, user["sub"])
return {"share_token": token}
@router.delete("/{prompt_id}/share")
async def revoke_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
"""Revoke a share token for a generation. Only the owner may revoke."""
from generation_db import revoke_share as db_revoke_share
deleted = db_revoke_share(prompt_id, user["sub"])
if not deleted:
raise HTTPException(404, "No active share found for this generation")
return {"ok": True}

View File

@@ -0,0 +1,194 @@
"""GET/POST/DELETE /api/inputs; GET /api/inputs/{id}/image; POST /api/inputs/{id}/activate"""
from __future__ import annotations
import logging
import mimetypes
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, UploadFile
from fastapi.responses import Response
from web.auth import require_auth
from web.deps import get_config, get_user_registry
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("")
async def list_inputs(_: dict = Depends(require_auth)):
"""List all input images (Discord + web uploads)."""
from input_image_db import get_all_images
rows = get_all_images()
return [dict(r) for r in rows]
@router.post("")
async def upload_input(
file: UploadFile = File(...),
slot_key: Optional[str] = Form(default=None),
user: dict = Depends(require_auth),
):
"""
Upload an input image.
Stores image bytes directly in SQLite. If *slot_key* is provided the
image is immediately activated for that slot (writes to ComfyUI input
folder and updates the user's state override).
The physical slot file uses a namespaced key ``<user_label>_<slot_key>``
so concurrent users each get their own active image file.
"""
config = get_config()
if config is None:
raise HTTPException(503, "Config not available")
data = await file.read()
filename = file.filename or "upload.png"
from input_image_db import upsert_image, activate_image_for_slot
row_id = upsert_image(
original_message_id=0, # sentinel for web uploads
bot_reply_id=0,
channel_id=0,
filename=filename,
image_data=data,
)
activated_filename: str | None = None
if slot_key:
user_label: str = user["sub"]
namespaced_key = f"{user_label}_{slot_key}"
activated_filename = activate_image_for_slot(
row_id, namespaced_key, config.comfy_input_path
)
registry = get_user_registry()
if registry:
registry.get_state_manager(user_label).set_override(slot_key, activated_filename)
else:
from web.deps import get_comfy
comfy = get_comfy()
if comfy:
comfy.state_manager.set_override(slot_key, activated_filename)
return {"id": row_id, "filename": filename, "slot_key": slot_key, "activated_filename": activated_filename}
@router.post("/{row_id}/activate")
async def activate_input(
row_id: int,
slot_key: str = Body(default="input_image", embed=True),
user: dict = Depends(require_auth),
):
"""Write the stored image to the ComfyUI input folder and set the user's slot override."""
config = get_config()
if config is None:
raise HTTPException(503, "Config not available")
from input_image_db import get_image, activate_image_for_slot
row = get_image(row_id)
if row is None:
raise HTTPException(404, "Image not found")
user_label: str = user["sub"]
namespaced_key = f"{user_label}_{slot_key}"
try:
filename = activate_image_for_slot(row_id, namespaced_key, config.comfy_input_path)
except ValueError as exc:
raise HTTPException(409, str(exc))
registry = get_user_registry()
if registry:
registry.get_state_manager(user_label).set_override(slot_key, filename)
else:
from web.deps import get_comfy
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "State manager not available")
comfy.state_manager.set_override(slot_key, filename)
return {"ok": True, "slot_key": slot_key, "filename": filename}
@router.delete("/{row_id}")
async def delete_input(row_id: int, _: dict = Depends(require_auth)):
"""Delete an input image record (and its active slot file if applicable)."""
from input_image_db import get_image, delete_image
row = get_image(row_id)
if row is None:
raise HTTPException(404, "Image not found")
config = get_config()
delete_image(row_id, comfy_input_path=config.comfy_input_path if config else None)
return {"ok": True}
@router.get("/{row_id}/image")
async def get_input_image(row_id: int, _: dict = Depends(require_auth)):
"""Serve the raw image bytes stored in the database for a given input image row."""
from input_image_db import get_image, get_image_data
row = get_image(row_id)
if row is None:
raise HTTPException(404, "Image not found")
data = get_image_data(row_id)
if data is None:
raise HTTPException(404, "Image data not available — re-upload to backfill")
mime, _ = mimetypes.guess_type(row["filename"])
return Response(content=data, media_type=mime or "application/octet-stream")
def _pil_resize_response(data: bytes, filename: str, max_size: int, quality: int) -> Response:
"""Resize image bytes with Pillow and return a JPEG Response. Raises on failure."""
import io
from PIL import Image as _PIL
img = _PIL.open(io.BytesIO(data))
img.thumbnail((max_size, max_size), _PIL.LANCZOS)
buf = io.BytesIO()
img.convert("RGB").save(buf, "JPEG", quality=quality, optimize=True)
return Response(
content=buf.getvalue(),
media_type="image/jpeg",
headers={"Cache-Control": "public, max-age=86400"},
)
@router.get("/{row_id}/thumb")
async def get_input_thumb(row_id: int, _: dict = Depends(require_auth)):
"""Serve a small compressed thumbnail (max 200 px, JPEG 65 %) for fast previews."""
from input_image_db import get_image, get_image_data
row = get_image(row_id)
if row is None:
raise HTTPException(404, "Image not found")
data = get_image_data(row_id)
if data is None:
raise HTTPException(404, "Image data not available — re-upload to backfill")
try:
return _pil_resize_response(data, row["filename"], max_size=200, quality=65)
except Exception:
mime, _ = mimetypes.guess_type(row["filename"])
return Response(content=data, media_type=mime or "application/octet-stream")
@router.get("/{row_id}/mid")
async def get_input_mid(row_id: int, _: dict = Depends(require_auth)):
"""Serve a medium compressed image (max 800 px, JPEG 80 %) for progressive loading."""
from input_image_db import get_image, get_image_data
row = get_image(row_id)
if row is None:
raise HTTPException(404, "Image not found")
data = get_image_data(row_id)
if data is None:
raise HTTPException(404, "Image data not available — re-upload to backfill")
try:
return _pil_resize_response(data, row["filename"], max_size=800, quality=80)
except Exception:
mime, _ = mimetypes.guess_type(row["filename"])
return Response(content=data, media_type=mime or "application/octet-stream")

View File

@@ -0,0 +1,153 @@
"""CRUD for workflow presets via /api/presets"""
from __future__ import annotations
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from web.auth import require_auth
from web.deps import get_comfy, get_user_registry
router = APIRouter()
class SavePresetRequest(BaseModel):
name: str
description: Optional[str] = None
class SaveFromHistoryRequest(BaseModel):
name: str
description: Optional[str] = None
def _get_pm():
from web.deps import get_bot
bot = get_bot()
pm = getattr(bot, "preset_manager", None) if bot else None
if pm is None:
from preset_manager import PresetManager
pm = PresetManager()
return pm
@router.get("")
async def list_presets(_: dict = Depends(require_auth)):
pm = _get_pm()
return {"presets": pm.list_preset_details()}
@router.post("")
async def save_preset(body: SavePresetRequest, user: dict = Depends(require_auth)):
"""Capture the user's overrides + workflow template as a named preset."""
user_label: str = user["sub"]
registry = get_user_registry()
pm = _get_pm()
if registry:
workflow_template = registry.get_workflow_template(user_label)
overrides = registry.get_state_manager(user_label).get_overrides()
else:
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "ComfyUI not available")
workflow_template = comfy.get_workflow_template()
overrides = comfy.state_manager.get_overrides()
try:
pm.save(body.name, workflow_template, overrides, owner=user_label, description=body.description)
except ValueError as exc:
raise HTTPException(400, str(exc))
return {"ok": True, "name": body.name}
@router.get("/{name}")
async def get_preset(name: str, _: dict = Depends(require_auth)):
pm = _get_pm()
data = pm.load(name)
if data is None:
raise HTTPException(404, "Preset not found")
return data
@router.post("/{name}/load")
async def load_preset(name: str, user: dict = Depends(require_auth)):
"""Restore overrides (and optionally workflow template) from a preset into the user's state."""
pm = _get_pm()
data = pm.load(name)
if data is None:
raise HTTPException(404, "Preset not found")
user_label: str = user["sub"]
registry = get_user_registry()
if registry:
wf = data.get("workflow")
if wf:
registry.set_workflow(user_label, wf, name)
else:
# No workflow in preset — just clear overrides and restore state
registry.get_state_manager(user_label).clear_overrides()
state = data.get("state", {})
sm = registry.get_state_manager(user_label)
for k, v in state.items():
if v is not None:
sm.set_override(k, v)
else:
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "ComfyUI not available")
comfy.state_manager.clear_overrides()
state = data.get("state", {})
for k, v in state.items():
if v is not None:
comfy.state_manager.set_override(k, v)
wf = data.get("workflow")
if wf:
comfy.workflow_manager.set_workflow_template(wf)
return {"ok": True, "name": name, "overrides_restored": list(data.get("state", {}).keys())}
@router.delete("/{name}")
async def delete_preset(name: str, user: dict = Depends(require_auth)):
pm = _get_pm()
data = pm.load(name)
if data is None:
raise HTTPException(404, "Preset not found")
user_label: str = user["sub"]
is_admin = user.get("admin") is True
owner = data.get("owner")
if owner is not None and owner != user_label and not is_admin:
raise HTTPException(403, "You do not have permission to delete this preset")
pm.delete(name)
return {"ok": True}
@router.post("/from-history/{prompt_id}")
async def save_preset_from_history(
prompt_id: str,
body: SaveFromHistoryRequest,
user: dict = Depends(require_auth),
):
"""Create a preset from a past generation's overrides."""
from generation_db import get_generation_full
gen = get_generation_full(prompt_id)
if gen is None:
raise HTTPException(404, "Generation not found")
user_label: str = user["sub"]
is_admin = user.get("admin") is True
if not is_admin and gen.get("user_label") != user_label:
raise HTTPException(404, "Generation not found")
pm = _get_pm()
try:
pm.save(body.name, None, gen["overrides"], owner=user_label, description=body.description)
except ValueError as exc:
raise HTTPException(400, str(exc))
return {"ok": True, "name": body.name}

View File

@@ -0,0 +1,90 @@
"""GET/POST /api/server/{action}; GET /api/logs/tail"""
from __future__ import annotations
import logging
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException
from web.auth import require_auth
from web.deps import get_config, get_comfy
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/server/status")
async def server_status(_: dict = Depends(require_auth)):
"""Return NSSM service state and HTTP health."""
config = get_config()
if config is None:
raise HTTPException(503, "Config not available")
from commands.server import get_service_state
import asyncio
async def _false():
return False
comfy = get_comfy()
service_state, http_ok = await asyncio.gather(
get_service_state(config.comfy_service_name),
comfy.check_connection() if comfy else _false(),
)
return {"service_state": service_state, "http_reachable": http_ok}
@router.post("/server/{action}")
async def server_action(action: str, _: dict = Depends(require_auth)):
"""Control the ComfyUI service: start | stop | restart | install | uninstall"""
config = get_config()
if config is None:
raise HTTPException(503, "Config not available")
valid_actions = {"start", "stop", "restart", "install", "uninstall"}
if action not in valid_actions:
raise HTTPException(400, f"Invalid action '{action}'")
from commands.server import _nssm, _install_service
import asyncio
try:
if action == "install":
ok, msg = await _install_service(config)
if not ok:
raise HTTPException(500, msg)
elif action == "uninstall":
await _nssm("stop", config.comfy_service_name)
await _nssm("remove", config.comfy_service_name, "confirm")
elif action == "start":
await _nssm("start", config.comfy_service_name)
elif action == "stop":
await _nssm("stop", config.comfy_service_name)
elif action == "restart":
await _nssm("restart", config.comfy_service_name)
except HTTPException:
raise
except Exception as exc:
raise HTTPException(500, str(exc))
return {"ok": True, "action": action}
@router.get("/logs/tail")
async def tail_logs(lines: int = 100, _: dict = Depends(require_auth)):
"""Tail the ComfyUI log file."""
config = get_config()
if config is None or not config.comfy_log_dir:
raise HTTPException(503, "Log directory not configured")
log_dir = Path(config.comfy_log_dir)
log_file = log_dir / "comfyui.log"
if not log_file.exists():
return {"lines": []}
try:
with open(log_file, "r", encoding="utf-8", errors="replace") as f:
all_lines = f.readlines()
tail = all_lines[-min(lines, len(all_lines)):]
return {"lines": [ln.rstrip("\n") for ln in tail]}
except Exception as exc:
raise HTTPException(500, str(exc))

View File

@@ -0,0 +1,85 @@
"""GET /api/share/{token}; GET /api/share/{token}/file/{filename}"""
from __future__ import annotations
import base64
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
from web.auth import require_auth
router = APIRouter()
@router.get("/{token}")
async def get_share(token: str, _: dict = Depends(require_auth)):
"""Fetch share metadata and images. Any authenticated user may view a valid share link."""
from generation_db import get_share_by_token, get_files
gen = get_share_by_token(token)
if gen is None:
raise HTTPException(404, "Share not found or revoked")
files = get_files(gen["prompt_id"])
return {
"prompt_id": gen["prompt_id"],
"created_at": gen["created_at"],
"overrides": gen["overrides"],
"seed": gen["seed"],
"images": [
{
"filename": f["filename"],
"data": base64.b64encode(f["data"]).decode() if not f["mime_type"].startswith("video/") else None,
"mime_type": f["mime_type"],
}
for f in files
],
}
@router.get("/{token}/file/{filename}")
async def get_share_file(
token: str,
filename: str,
request: Request,
_: dict = Depends(require_auth),
):
"""Stream a single output file via share token, with HTTP range support for video seeking."""
from generation_db import get_share_by_token, get_files
gen = get_share_by_token(token)
if gen is None:
raise HTTPException(404, "Share not found or revoked")
files = get_files(gen["prompt_id"])
matched = next((f for f in files if f["filename"] == filename), None)
if matched is None:
raise HTTPException(404, f"File {filename!r} not found")
data: bytes = matched["data"]
mime: str = matched["mime_type"]
total = len(data)
range_header = request.headers.get("range")
if range_header:
range_val = range_header.replace("bytes=", "")
start_str, _, end_str = range_val.partition("-")
start = int(start_str) if start_str else 0
end = int(end_str) if end_str else total - 1
end = min(end, total - 1)
chunk = data[start : end + 1]
return Response(
content=chunk,
status_code=206,
media_type=mime,
headers={
"Content-Range": f"bytes {start}-{end}/{total}",
"Accept-Ranges": "bytes",
"Content-Length": str(len(chunk)),
},
)
return Response(
content=data,
media_type=mime,
headers={
"Accept-Ranges": "bytes",
"Content-Length": str(total),
},
)

View File

@@ -0,0 +1,53 @@
"""GET/PUT /api/state; DELETE /api/state/{key}"""
from __future__ import annotations
from typing import Any, Dict
from fastapi import APIRouter, Depends, HTTPException
from web.auth import require_auth
from web.deps import get_config, get_user_registry
router = APIRouter()
def _get_user_sm(user: dict):
"""Return the per-user WorkflowStateManager, raising 503 if unavailable."""
registry = get_user_registry()
if registry is None:
raise HTTPException(503, "State manager not available")
return registry.get_state_manager(user["sub"])
@router.get("/state")
async def get_state(user: dict = Depends(require_auth)):
"""Return all current overrides for the authenticated user."""
sm = _get_user_sm(user)
return sm.get_overrides()
@router.put("/state")
async def put_state(body: Dict[str, Any], user: dict = Depends(require_auth)):
"""Merge override values. Pass ``null`` as a value to delete a key."""
sm = _get_user_sm(user)
for key, value in body.items():
if value is None:
sm.delete_override(key)
else:
sm.set_override(key, value)
return sm.get_overrides()
@router.delete("/state/{key}")
async def delete_state_key(key: str, user: dict = Depends(require_auth)):
"""Remove a single override key, and clean up any associated slot file."""
sm = _get_user_sm(user)
sm.delete_override(key)
config = get_config()
if config:
from input_image_db import deactivate_image_slot
user_label: str = user["sub"]
deactivate_image_slot(f"{user_label}_{key}", config.comfy_input_path)
return {"ok": True, "key": key}

View File

@@ -0,0 +1,74 @@
"""GET /api/status — polling fallback for clients that can't use WebSocket"""
from __future__ import annotations
import asyncio
from fastapi import APIRouter, Depends
from web.auth import require_auth
from web.deps import get_bot, get_comfy, get_config
router = APIRouter()
@router.get("/status")
async def get_status(_: dict = Depends(require_auth)):
"""Return a full status snapshot."""
bot = get_bot()
comfy = get_comfy()
config = get_config()
snap: dict = {}
if bot is not None:
import datetime as _dt
lat = bot.latency
lat_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0
start = getattr(bot, "start_time", None)
uptime = ""
if start:
delta = _dt.datetime.now(_dt.timezone.utc) - start
total = int(delta.total_seconds())
h, rem = divmod(total, 3600)
m, s = divmod(rem, 60)
uptime = f"{h}h {m}m {s}s" if h else (f"{m}m {s}s" if m else f"{s}s")
snap["bot"] = {"latency_ms": lat_ms, "uptime": uptime}
if comfy is not None:
q_task = asyncio.create_task(comfy.get_comfy_queue())
conn_task = asyncio.create_task(comfy.check_connection())
q, reachable = await asyncio.gather(q_task, conn_task)
pending = len(q.get("queue_pending", [])) if q else 0
running = len(q.get("queue_running", [])) if q else 0
wm = getattr(comfy, "workflow_manager", None)
wf_loaded = wm is not None and wm.get_workflow_template() is not None
snap["comfy"] = {
"server": comfy.server_address,
"reachable": reachable,
"queue_pending": pending,
"queue_running": running,
"workflow_loaded": wf_loaded,
"last_seed": comfy.last_seed,
"total_generated": comfy.total_generated,
}
snap["overrides"] = comfy.state_manager.get_overrides()
if config is not None:
from commands.server import get_service_state
service_state = await get_service_state(config.comfy_service_name)
snap["service"] = {"state": service_state}
try:
from media_uploader import get_stats as us_fn, is_running as ur_fn
us = us_fn()
snap["upload"] = {
"configured": bool(config.media_upload_user),
"running": ur_fn(),
"total_ok": us.total_ok,
"total_fail": us.total_fail,
}
except Exception:
pass
return snap

View File

@@ -0,0 +1,175 @@
"""
GET /api/workflow — current workflow info
GET /api/workflow/inputs — dynamic NodeInput list
GET /api/workflow/files — list files in workflows/
POST /api/workflow/upload — upload a workflow JSON
POST /api/workflow/load — load a workflow from workflows/
GET /api/workflow/models?type=checkpoints|loras — available models (60s TTL cache)
"""
from __future__ import annotations
import json
import logging
import time
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from web.auth import require_auth
from web.deps import get_comfy, get_config, get_inspector, get_user_registry
router = APIRouter()
logger = logging.getLogger(__name__)
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
_WORKFLOWS_DIR = _PROJECT_ROOT / "workflows"
# Simple in-memory TTL cache for models
_models_cache: dict = {}
_MODELS_TTL = 60.0
@router.get("")
async def get_workflow(user: dict = Depends(require_auth)):
"""Return basic info about the currently loaded workflow."""
user_label: str = user["sub"]
registry = get_user_registry()
if registry:
template = registry.get_workflow_template(user_label)
last_wf = registry.get_state_manager(user_label).get_last_workflow_file()
else:
# Fallback to global state when registry is unavailable
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "Workflow manager not available")
template = comfy.workflow_manager.get_workflow_template()
last_wf = comfy.state_manager.get_last_workflow_file()
return {
"loaded": template is not None,
"node_count": len(template) if template else 0,
"last_workflow_file": last_wf,
}
@router.get("/inputs")
async def get_workflow_inputs(user: dict = Depends(require_auth)):
"""Return dynamic NodeInput list (common + advanced) for the current workflow."""
user_label: str = user["sub"]
inspector = get_inspector()
if inspector is None:
raise HTTPException(503, "Workflow components not available")
registry = get_user_registry()
if registry:
template = registry.get_workflow_template(user_label)
overrides = registry.get_state_manager(user_label).get_overrides()
else:
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "Workflow components not available")
template = comfy.workflow_manager.get_workflow_template()
overrides = comfy.state_manager.get_overrides()
if template is None:
return {"common": [], "advanced": []}
inputs = inspector.inspect(template)
result = []
for ni in inputs:
val = overrides.get(ni.key, ni.current_value)
result.append({
"key": ni.key,
"label": ni.label,
"input_type": ni.input_type,
"current_value": val,
"node_class": ni.node_class,
"node_title": ni.node_title,
"is_common": ni.is_common,
})
common = [r for r in result if r["is_common"]]
advanced = [r for r in result if not r["is_common"]]
return {"common": common, "advanced": advanced}
@router.get("/files")
async def list_workflow_files(_: dict = Depends(require_auth)):
"""List .json files in the workflows/ folder."""
_WORKFLOWS_DIR.mkdir(parents=True, exist_ok=True)
files = sorted(p.name for p in _WORKFLOWS_DIR.glob("*.json"))
return {"files": files}
@router.post("/upload")
async def upload_workflow(
file: UploadFile = File(...),
_: dict = Depends(require_auth),
):
"""Upload a workflow JSON to the workflows/ folder."""
_WORKFLOWS_DIR.mkdir(parents=True, exist_ok=True)
filename = file.filename or "workflow.json"
if not filename.endswith(".json"):
filename += ".json"
data = await file.read()
try:
json.loads(data) # validate JSON
except json.JSONDecodeError as exc:
raise HTTPException(400, f"Invalid JSON: {exc}")
dest = _WORKFLOWS_DIR / filename
dest.write_bytes(data)
return {"ok": True, "filename": filename}
@router.post("/load")
async def load_workflow(filename: str = Form(...), user: dict = Depends(require_auth)):
"""Load a workflow from the workflows/ folder into the user's isolated state."""
wf_path = _WORKFLOWS_DIR / filename
if not wf_path.exists():
raise HTTPException(404, f"Workflow file '{filename}' not found")
try:
with open(wf_path, "r", encoding="utf-8") as f:
workflow = json.load(f)
except Exception as exc:
raise HTTPException(500, str(exc))
registry = get_user_registry()
if registry:
registry.set_workflow(user["sub"], workflow, filename)
else:
# Fallback: update global state when registry unavailable
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "ComfyUI not available")
comfy.workflow_manager.set_workflow_template(workflow)
comfy.state_manager.clear_overrides()
comfy.state_manager.set_last_workflow_file(filename)
inspector = get_inspector()
node_count = len(workflow)
inputs_count = len(inspector.inspect(workflow)) if inspector else 0
return {
"ok": True,
"filename": filename,
"node_count": node_count,
"inputs_count": inputs_count,
}
@router.get("/models")
async def get_models(type: str = "checkpoints", _: dict = Depends(require_auth)):
"""Return available model names from ComfyUI (60s TTL cache)."""
global _models_cache
now = time.time()
cache_key = type
cached = _models_cache.get(cache_key)
if cached and (now - cached["ts"]) < _MODELS_TTL:
return {"type": type, "models": cached["models"]}
comfy = get_comfy()
if comfy is None:
raise HTTPException(503, "ComfyUI not available")
models = await comfy.get_models(type)
_models_cache[cache_key] = {"models": models, "ts": now}
return {"type": type, "models": models}

61
web/routers/ws_router.py Normal file
View File

@@ -0,0 +1,61 @@
"""WebSocket /ws?token=<jwt> — real-time event stream"""
from __future__ import annotations
import asyncio
import logging
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from web.auth import verify_ws_token
from web.ws_bus import get_bus
router = APIRouter()
logger = logging.getLogger(__name__)
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: str = ""):
"""
Authenticate via JWT query param or ttb_session cookie, then stream events from WSBus.
Events common to all users: status_snapshot, queue_update, node_executing, server_state
Events private to submitter: generation_complete, generation_error
"""
payload = verify_ws_token(token)
if payload is None:
# Fallback: browsers send cookies automatically with WebSocket connections
cookie_token = websocket.cookies.get("ttb_session", "")
payload = verify_ws_token(cookie_token)
if payload is None:
await websocket.close(code=4001, reason="Unauthorized")
return
user_label: str = payload.get("sub", "anonymous")
bus = get_bus()
queue = bus.subscribe(user_label)
await websocket.accept()
logger.info("WS connected: user=%s", user_label)
try:
while True:
# Wait for an event from the bus
try:
frame = await asyncio.wait_for(queue.get(), timeout=30.0)
except asyncio.TimeoutError:
# Send a keepalive ping
try:
await websocket.send_text('{"type":"ping"}')
except Exception:
break
continue
try:
await websocket.send_text(frame)
except Exception:
break
except WebSocketDisconnect:
pass
finally:
bus.unsubscribe(user_label, queue)
logger.info("WS disconnected: user=%s", user_label)

122
web/ws_bus.py Normal file
View File

@@ -0,0 +1,122 @@
"""
web/ws_bus.py
=============
In-process WebSocket event bus.
All connected web clients share a single WSBus instance. Events are
delivered per-user (private results) or to all users (shared status).
Usage::
bus = WSBus()
# Subscribe (returns a queue; caller reads from it)
q = bus.subscribe("alice")
# Broadcast to all
await bus.broadcast("status_snapshot", {...})
# Broadcast to one user (all their open tabs)
await bus.broadcast_to_user("alice", "generation_complete", {...})
# Unsubscribe when WS disconnects
bus.unsubscribe("alice", q)
Event frame format sent on wire:
{"type": "event_name", "data": {...}, "ts": 1234567890.123}
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from typing import Any, Dict, Set
logger = logging.getLogger(__name__)
class WSBus:
"""
Per-user broadcast bus backed by asyncio queues.
Thread-safe as long as all callers run in the same event loop.
"""
def __init__(self) -> None:
# user_label → set of asyncio.Queue
self._clients: Dict[str, Set[asyncio.Queue]] = {}
# ------------------------------------------------------------------
# Subscription lifecycle
# ------------------------------------------------------------------
def subscribe(self, user_label: str) -> asyncio.Queue:
"""Register a new client connection. Returns the queue to read from."""
q: asyncio.Queue = asyncio.Queue(maxsize=256)
self._clients.setdefault(user_label, set()).add(q)
logger.debug("WSBus: %s subscribed (%d queues)", user_label,
len(self._clients[user_label]))
return q
def unsubscribe(self, user_label: str, queue: asyncio.Queue) -> None:
"""Remove a client connection."""
queues = self._clients.get(user_label, set())
queues.discard(queue)
if not queues:
self._clients.pop(user_label, None)
logger.debug("WSBus: %s unsubscribed", user_label)
@property
def connected_users(self) -> list[str]:
"""List of user labels with at least one active connection."""
return list(self._clients.keys())
@property
def total_connections(self) -> int:
return sum(len(qs) for qs in self._clients.values())
# ------------------------------------------------------------------
# Broadcasting
# ------------------------------------------------------------------
def _frame(self, event_type: str, data: Any) -> str:
return json.dumps({"type": event_type, "data": data, "ts": time.time()})
async def broadcast(self, event_type: str, data: Any) -> None:
"""Send an event to ALL connected clients."""
frame = self._frame(event_type, data)
for queues in list(self._clients.values()):
for q in list(queues):
try:
q.put_nowait(frame)
except asyncio.QueueFull:
logger.warning("WSBus: queue full, dropping %s event", event_type)
async def broadcast_to_user(
self, user_label: str, event_type: str, data: Any
) -> None:
"""Send an event to all connections belonging to *user_label*."""
queues = self._clients.get(user_label, set())
if not queues:
logger.debug("WSBus: no clients for user '%s', dropping %s", user_label, event_type)
return
frame = self._frame(event_type, data)
for q in list(queues):
try:
q.put_nowait(frame)
except asyncio.QueueFull:
logger.warning("WSBus: queue full for %s, dropping %s", user_label, event_type)
# Module-level singleton (set by web/app.py)
_bus: WSBus | None = None
def get_bus() -> WSBus:
global _bus
if _bus is None:
_bus = WSBus()
return _bus

397
workflow_inspector.py Normal file
View File

@@ -0,0 +1,397 @@
"""
workflow_inspector.py
=====================
Dynamic workflow node inspection and injection for the Discord ComfyUI bot.
Replaces the hardcoded node-finding methods in WorkflowManager with a
general-purpose inspector that works with any ComfyUI workflow. The
inspector discovers injectable inputs at load time by walking the workflow
JSON, classifying each scalar input by class_type + input_name, and
assigning stable human-readable keys (e.g. ``"prompt"``, ``"seed"``,
``"input_image"``) that can be stored in WorkflowStateManager and used by
both Discord commands and the web UI.
Key-assignment rules
--------------------
- CLIPTextEncode / text → ``"prompt"`` / ``"negative_prompt"`` / ``"text_{node_id}"``
- LoadImage / image → ``"input_image"`` (first), ``{title_slug}`` (subsequent)
- KSampler* / seed → ``"seed"``
- KSampler / steps|cfg|sampler_name|scheduler|denoise → same name
- EmptyLatentImage / width|height → ``"width"`` / ``"height"``
- CheckpointLoaderSimple / ckpt_name → ``"checkpoint"``
- LoraLoader / lora_name → ``"lora_{node_id}"``
- Other scalars → ``"{class_slug}_{node_id}_{input_name}"``
"""
from __future__ import annotations
import copy
import random
import re
from dataclasses import dataclass
from typing import Any, Optional
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _slugify(text: str) -> str:
"""Convert arbitrary text to a lowercase underscore key."""
text = text.lower()
text = re.sub(r"[^a-z0-9]+", "_", text)
return text.strip("_") or "node"
def _numeric_sort_key(node_id: str):
"""Sort node IDs numerically where possible."""
try:
return (0, int(node_id))
except (ValueError, TypeError):
return (1, node_id)
# ---------------------------------------------------------------------------
# NodeInput dataclass
# ---------------------------------------------------------------------------
@dataclass
class NodeInput:
"""
Descriptor for a single injectable input within a ComfyUI workflow.
Attributes
----------
node_id : str
The node's key in the workflow dict.
node_class : str
``class_type`` of the node (e.g. ``"KSampler"``).
node_title : str
Value of ``_meta.title`` for the node (may be empty).
input_name : str
The input field name within the node's ``inputs`` dict.
input_type : str
Semantic type: ``"text"``, ``"seed"``, ``"image"``, ``"integer"``,
``"float"``, ``"string"``, ``"checkpoint"``, ``"lora"``.
current_value : Any
The value currently stored in the workflow template.
label : str
Human-readable display label (``"NodeTitle / input_name"``).
key : str
Stable short key used by Discord commands and state storage.
is_common : bool
True for prompts, seeds, and image inputs (shown prominently in UI).
"""
node_id: str
node_class: str
node_title: str
input_name: str
input_type: str
current_value: Any
label: str
key: str
is_common: bool
# ---------------------------------------------------------------------------
# WorkflowInspector
# ---------------------------------------------------------------------------
class WorkflowInspector:
"""
Inspects and modifies ComfyUI workflow JSON.
Usage::
inspector = WorkflowInspector()
inputs = inspector.inspect(workflow)
modified, applied = inspector.inject_overrides(workflow, {"prompt": "a cat"})
"""
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def inspect(self, workflow: dict[str, Any]) -> list[NodeInput]:
"""
Walk a workflow and return all injectable scalar inputs.
Parameters
----------
workflow : dict
ComfyUI workflow in API format (node_id → node_dict).
Returns
-------
list[NodeInput]
All injectable inputs, common ones first, advanced sorted by key.
"""
inputs: list[NodeInput] = []
load_image_count = 0 # for unique LoadImage key assignment
for node_id, node in sorted(workflow.items(), key=lambda kv: _numeric_sort_key(kv[0])):
if not isinstance(node, dict):
continue
class_type: str = node.get("class_type", "")
title: str = node.get("_meta", {}).get("title", "") or ""
node_inputs = node.get("inputs", {})
if not isinstance(node_inputs, dict):
continue
for input_name, value in node_inputs.items():
# Skip node-reference inputs (they are lists like [node_id, output_slot])
if isinstance(value, list):
continue
# Skip None values — no useful type info
if value is None:
continue
ni = self._classify_input(
node_id=node_id,
class_type=class_type,
title=title,
input_name=input_name,
value=value,
load_image_count=load_image_count,
)
if ni is not None:
if ni.input_type == "image":
load_image_count += 1
inputs.append(ni)
# Sort: common first, then advanced (both groups sorted by key for stability)
inputs.sort(key=lambda x: (0 if x.is_common else 1, x.key))
return inputs
def inject_overrides(
self,
workflow: dict[str, Any],
overrides: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
"""
Deep-copy a workflow and inject override values.
Seeds that are absent from *overrides* or set to ``-1`` are
auto-randomized. All other injectable keys found via
:meth:`inspect` are updated if present in *overrides*.
Parameters
----------
workflow : dict
The workflow template (not mutated).
overrides : dict
Mapping of ``NodeInput.key → value`` to inject.
Returns
-------
tuple[dict, dict]
``(modified_workflow, applied_values)`` where *applied_values*
maps each key that was actually written (including auto-generated
seeds) to the value that was used.
"""
wf = copy.deepcopy(workflow)
applied: dict[str, Any] = {}
# Inspect the deep copy to build key → [(node_id, input_name), …]
inputs = self.inspect(wf)
# Group targets by key
key_targets: dict[str, list[tuple[str, str]]] = {}
key_itype: dict[str, str] = {}
for ni in inputs:
if ni.key not in key_targets:
key_targets[ni.key] = []
key_itype[ni.key] = ni.input_type
key_targets[ni.key].append((ni.node_id, ni.input_name))
for key, targets in key_targets.items():
itype = key_itype[key]
override_val = overrides.get(key)
if itype == "seed":
# -1 sentinel or absent → auto-randomize
if override_val is None or override_val == -1:
seed = random.randint(0, 2 ** 32 - 1)
else:
seed = int(override_val)
for node_id, input_name in targets:
wf[node_id]["inputs"][input_name] = seed
applied[key] = seed
elif override_val is not None:
for node_id, input_name in targets:
wf[node_id]["inputs"][input_name] = override_val
applied[key] = override_val
return wf, applied
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _classify_input(
self,
*,
node_id: str,
class_type: str,
title: str,
input_name: str,
value: Any,
load_image_count: int,
) -> Optional[NodeInput]:
"""
Return a NodeInput for one field, or None to skip it entirely.
"""
title_display = title or class_type
# ---- CLIPTextEncode → positive/negative prompt ----
if class_type == "CLIPTextEncode" and input_name == "text":
t = title.lower()
if "positive" in t:
key = "prompt"
elif "negative" in t:
key = "negative_prompt"
else:
key = f"text_{node_id}"
return NodeInput(
node_id=node_id,
node_class=class_type,
node_title=title,
input_name=input_name,
input_type="text",
current_value=value,
label=f"{title_display} / {input_name}",
key=key,
is_common=True,
)
# ---- LoadImage → input image ----
if class_type == "LoadImage" and input_name == "image":
if load_image_count == 0:
key = "input_image"
else:
slug = _slugify(title) if title else f"image_{node_id}"
key = slug or f"image_{node_id}"
return NodeInput(
node_id=node_id,
node_class=class_type,
node_title=title,
input_name=input_name,
input_type="image",
current_value=value,
label=f"{title_display} / {input_name}",
key=key,
is_common=True,
)
# ---- KSampler / KSamplerAdvanced ----
if class_type in ("KSampler", "KSamplerAdvanced"):
if input_name in ("seed", "noise_seed"):
return NodeInput(
node_id=node_id,
node_class=class_type,
node_title=title,
input_name=input_name,
input_type="seed",
current_value=value,
label=f"{title_display} / seed",
key="seed",
is_common=True,
)
_ksampler_advanced = {
"steps": ("integer", "steps", False),
"cfg": ("float", "cfg", False),
"sampler_name": ("string", "sampler_name", False),
"scheduler": ("string", "scheduler", False),
"denoise": ("float", "denoise", False),
}
if input_name in _ksampler_advanced:
itype, key, is_common = _ksampler_advanced[input_name]
return NodeInput(
node_id=node_id,
node_class=class_type,
node_title=title,
input_name=input_name,
input_type=itype,
current_value=value,
label=f"{title_display} / {input_name}",
key=key,
is_common=is_common,
)
# ---- EmptyLatentImage ----
if class_type == "EmptyLatentImage" and input_name in ("width", "height"):
return NodeInput(
node_id=node_id,
node_class=class_type,
node_title=title,
input_name=input_name,
input_type="integer",
current_value=value,
label=f"{title_display} / {input_name}",
key=input_name,
is_common=False,
)
# ---- CheckpointLoaderSimple ----
if class_type == "CheckpointLoaderSimple" and input_name == "ckpt_name":
return NodeInput(
node_id=node_id,
node_class=class_type,
node_title=title,
input_name=input_name,
input_type="checkpoint",
current_value=value,
label=f"{title_display} / checkpoint",
key="checkpoint",
is_common=False,
)
# ---- LoraLoader ----
if class_type == "LoraLoader" and input_name == "lora_name":
return NodeInput(
node_id=node_id,
node_class=class_type,
node_title=title,
input_name=input_name,
input_type="lora",
current_value=value,
label=f"{title_display} / lora",
key=f"lora_{node_id}",
is_common=False,
)
# ---- Skip non-scalar or already-handled classes ----
_handled_classes = {
"CLIPTextEncode", "LoadImage", "KSampler", "KSamplerAdvanced",
"EmptyLatentImage", "CheckpointLoaderSimple", "LoraLoader",
}
if class_type in _handled_classes:
return None # unrecognised field for a known class → skip
# ---- Generic scalar fallback ----
if isinstance(value, bool):
return None # booleans aren't useful override targets
if isinstance(value, int):
itype = "integer"
elif isinstance(value, float):
itype = "float"
elif isinstance(value, str):
itype = "string"
else:
return None # dicts, etc. — skip
key = f"{_slugify(class_type)}_{node_id}_{input_name}"
return NodeInput(
node_id=node_id,
node_class=class_type,
node_title=title,
input_name=input_name,
input_type=itype,
current_value=value,
label=f"{title_display} / {input_name}",
key=key,
is_common=False,
)

56
workflow_manager.py Normal file
View File

@@ -0,0 +1,56 @@
"""
workflow_manager.py
===================
Workflow template storage for the Discord ComfyUI bot.
This module provides a WorkflowManager class responsible solely for holding
the loaded workflow template. Node inspection and injection are handled by
:class:`~workflow_inspector.WorkflowInspector`.
"""
from __future__ import annotations
import logging
from typing import Dict, Optional, Any
logger = logging.getLogger(__name__)
class WorkflowManager:
"""
Stores and provides access to the current workflow template.
Parameters
----------
workflow_template : Optional[Dict[str, Any]]
An initial workflow dict. If None, no template is loaded and one
must be set via :meth:`set_workflow_template`.
"""
def __init__(self, workflow_template: Optional[Dict[str, Any]] = None) -> None:
self._workflow_template = workflow_template
def set_workflow_template(self, workflow: Dict[str, Any]) -> None:
"""
Set or replace the workflow template.
Parameters
----------
workflow : Dict[str, Any]
Workflow dictionary in ComfyUI API format.
"""
self._workflow_template = workflow
logger.info("Workflow template updated (%d nodes)", len(workflow))
def get_workflow_template(self) -> Optional[Dict[str, Any]]:
"""
Return the current workflow template.
Returns
-------
Optional[Dict[str, Any]]
The loaded template, or None if none has been set.
"""
return self._workflow_template

252
workflow_state.py Normal file
View File

@@ -0,0 +1,252 @@
"""
workflow_state.py
=================
Workflow state management for the Discord ComfyUI bot.
This module provides a WorkflowStateManager class that stores runtime
overrides for workflow parameters (prompt, negative_prompt, input_image,
seed, steps, cfg, …) in a generic key-value dict. Any NodeInput.key
produced by WorkflowInspector can be set as an override.
The old fixed fields (prompt, negative_prompt, input_image, seed) are
preserved as convenience wrappers for backward compatibility with existing
Discord commands and status_monitor.
State file format (current-workflow-changes.json)::
{
"overrides": {
"prompt": "a beautiful landscape",
"seed": 42,
...
},
"last_workflow_file": "my_workflow.json"
}
The old flat-key format is migrated automatically on first load.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
class WorkflowStateManager:
"""
Manages runtime workflow overrides in memory with optional persistence.
Override keys correspond to ``NodeInput.key`` values discovered by
:class:`~workflow_inspector.WorkflowInspector`. Common well-known keys
are ``"prompt"``, ``"negative_prompt"``, ``"input_image"``, ``"seed"``.
Parameters
----------
state_file : Optional[str]
Path to a JSON file for persisting overrides. Loaded on init if the
file exists; auto-saved on every change.
"""
def __init__(self, state_file: Optional[str] = None) -> None:
self._overrides: Dict[str, Any] = {}
self._last_workflow_file: Optional[str] = None
self._state_file = state_file
if self._state_file:
self._load_from_file()
# ------------------------------------------------------------------
# Persistence
# ------------------------------------------------------------------
def _load_from_file(self) -> None:
"""Load state from the configured JSON file if it exists."""
if not self._state_file:
return
state_path = Path(self._state_file)
if not state_path.exists():
logger.debug("State file %s not found, using empty state", self._state_file)
return
try:
with open(state_path, "r", encoding="utf-8") as f:
data = json.load(f)
# New format: {"overrides": {...}, "last_workflow_file": ...}
if "overrides" in data and isinstance(data["overrides"], dict):
self._overrides = data["overrides"]
self._last_workflow_file = data.get("last_workflow_file")
else:
# Migrate old flat format: {"prompt": ..., "negative_prompt": ..., ...}
self._overrides = {
k: v for k, v in data.items()
if v is not None and k not in ("last_workflow_file",)
}
self._last_workflow_file = data.get("last_workflow_file")
logger.info("Loaded workflow state from %s", self._state_file)
except Exception as exc:
logger.warning("Failed to load state from %s: %s", self._state_file, exc)
def save_to_file(self) -> None:
"""
Persist current overrides and last_workflow_file to the state JSON file.
Raises
------
RuntimeError
If no state file was configured.
"""
if not self._state_file:
raise RuntimeError("Cannot save state: no state file configured")
try:
data = {
"overrides": self._overrides,
"last_workflow_file": self._last_workflow_file,
}
with open(self._state_file, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4)
logger.debug("Saved workflow state to %s", self._state_file)
except Exception as exc:
logger.error("Failed to save state to %s: %s", self._state_file, exc)
raise
def _autosave(self) -> None:
"""Save to file silently if a state file is configured."""
if self._state_file:
try:
self.save_to_file()
except Exception:
pass # already logged inside save_to_file
# ------------------------------------------------------------------
# Generic override API
# ------------------------------------------------------------------
def get_overrides(self) -> Dict[str, Any]:
"""Return a shallow copy of all current overrides."""
return self._overrides.copy()
def set_override(self, key: str, value: Any) -> None:
"""Set a single override key and auto-save."""
self._overrides[key] = value
self._autosave()
def delete_override(self, key: str) -> None:
"""Remove a single override key (no-op if absent) and auto-save."""
self._overrides.pop(key, None)
self._autosave()
def clear_overrides(self) -> None:
"""Remove all override keys and auto-save."""
self._overrides = {}
self._autosave()
# ------------------------------------------------------------------
# Last-workflow-file tracking
# ------------------------------------------------------------------
def get_last_workflow_file(self) -> Optional[str]:
"""Return the last loaded workflow filename (for auto-load on restart)."""
return self._last_workflow_file
def set_last_workflow_file(self, filename: Optional[str]) -> None:
"""Record the last loaded workflow filename and auto-save."""
self._last_workflow_file = filename
self._autosave()
# ------------------------------------------------------------------
# Backward-compat: old get_changes / set_changes API
# ------------------------------------------------------------------
def get_changes(self) -> Dict[str, Any]:
"""
Alias for :meth:`get_overrides` retained for backward compatibility.
Returns a dict that always has ``prompt``, ``negative_prompt``,
``input_image``, and ``seed`` keys (value is ``None`` when unset)
so existing callers that rely on these specific keys still work.
"""
base: Dict[str, Any] = {
"prompt": None,
"negative_prompt": None,
"input_image": None,
"seed": None,
}
base.update(self._overrides)
return base
def set_changes(self, changes: Dict[str, Any], merge: bool = True) -> None:
"""
Set multiple overrides at once.
Parameters
----------
changes : dict
Key-value pairs to apply.
merge : bool
If True (default), merge with existing overrides.
If False, replace all overrides with these values (``None``
values are excluded).
"""
if merge:
for k, v in changes.items():
if v is not None:
self._overrides[k] = v
else:
self._overrides = {k: v for k, v in changes.items() if v is not None}
self._autosave()
# ------------------------------------------------------------------
# Convenience setters / getters for well-known keys
# ------------------------------------------------------------------
def set_prompt(self, prompt: str) -> None:
"""Set the positive prompt override."""
self.set_override("prompt", prompt)
def get_prompt(self) -> Optional[str]:
"""Return the positive prompt override, or None if not set."""
return self._overrides.get("prompt")
def set_negative_prompt(self, negative_prompt: str) -> None:
"""Set the negative prompt override."""
self.set_override("negative_prompt", negative_prompt)
def get_negative_prompt(self) -> Optional[str]:
"""Return the negative prompt override, or None if not set."""
return self._overrides.get("negative_prompt")
def set_input_image(self, input_image: str) -> None:
"""Set the input image override (filename)."""
self.set_override("input_image", input_image)
def get_input_image(self) -> Optional[str]:
"""Return the input image override, or None if not set."""
return self._overrides.get("input_image")
def set_seed(self, seed: int) -> None:
"""Pin a specific seed for deterministic generation."""
self.set_override("seed", seed)
def get_seed(self) -> Optional[int]:
"""Return the pinned seed, or None if randomising each run."""
return self._overrides.get("seed")
def clear_seed(self) -> None:
"""Clear the pinned seed, reverting to random generation."""
self.delete_override("seed")
def clear(self) -> None:
"""Reset all overrides (alias for :meth:`clear_overrides`)."""
self.clear_overrides()
def __repr__(self) -> str:
return (
f"WorkflowStateManager("
f"overrides={self._overrides!r}, "
f"last_workflow_file={self._last_workflow_file!r})"
)

0
workflows/.gitkeep Normal file
View File