Initial commit — ComfyUI Discord bot + web UI
Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI (React/TS/Vite/Tailwind), ComfyUI integration, generation history DB, preset manager, workflow inspector, and all supporting modules. Excluded from tracking: .env, invite_tokens.json, *.db (SQLite), current-workflow-changes.json, user_settings/, presets/, logs/, web-static/ (build output), frontend/node_modules/. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
14
.gitattributes
vendored
Normal file
14
.gitattributes
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# Normalize line endings to LF in the repository (CRLF on Windows checkout)
|
||||||
|
* text=auto eol=lf
|
||||||
|
|
||||||
|
# Binary files — no line-ending conversion
|
||||||
|
*.db binary
|
||||||
|
*.png binary
|
||||||
|
*.jpg binary
|
||||||
|
*.jpeg binary
|
||||||
|
*.gif binary
|
||||||
|
*.webp binary
|
||||||
|
*.mp4 binary
|
||||||
|
*.webm binary
|
||||||
|
*.zip binary
|
||||||
|
*.whl binary
|
||||||
81
.gitignore
vendored
Normal file
81
.gitignore
vendored
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
# ── Python ────────────────────────────────────────────────────────────────────
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
|
||||||
|
# ── Virtual environments ───────────────────────────────────────────────────────
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
ENV/
|
||||||
|
.venv/
|
||||||
|
|
||||||
|
# ── Secrets / environment ─────────────────────────────────────────────────────
|
||||||
|
# Contains DISCORD_BOT_TOKEN and other credentials — NEVER commit
|
||||||
|
.env
|
||||||
|
|
||||||
|
# Hashed invite tokens — auth credentials; regenerate via token_store.py CLI
|
||||||
|
invite_tokens.json
|
||||||
|
|
||||||
|
# ── SQLite databases ──────────────────────────────────────────────────────────
|
||||||
|
# generation_history.db — user generation records
|
||||||
|
# input_images.db — image BLOBs (can be hundreds of MB)
|
||||||
|
*.db
|
||||||
|
|
||||||
|
# ── Runtime / generated state ─────────────────────────────────────────────────
|
||||||
|
# Active workflow overrides (prompt, seed, etc.) — machine-local runtime state
|
||||||
|
current-workflow-changes.json
|
||||||
|
|
||||||
|
# Per-user persistent settings (created at runtime under user labels)
|
||||||
|
user_settings/
|
||||||
|
|
||||||
|
# User-created presets — runtime data; not project source
|
||||||
|
presets/
|
||||||
|
|
||||||
|
# NSSM / service log files
|
||||||
|
logs/
|
||||||
|
|
||||||
|
# ── Frontend build artefacts ──────────────────────────────────────────────────
|
||||||
|
# Regenerate with: cd frontend && npm run build
|
||||||
|
web-static/
|
||||||
|
|
||||||
|
# npm dependencies — restored with: cd frontend && npm install
|
||||||
|
frontend/node_modules/
|
||||||
|
|
||||||
|
# Vite cache
|
||||||
|
frontend/.vite/
|
||||||
|
|
||||||
|
# ── IDE / editor ──────────────────────────────────────────────────────────────
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# ── OS ────────────────────────────────────────────────────────────────────────
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# ── Syncthing ─────────────────────────────────────────────────────────────────
|
||||||
|
.stfolder/
|
||||||
|
.stignore
|
||||||
|
|
||||||
|
# ── Claude Code project files ─────────────────────────────────────────────────
|
||||||
|
# Local conversation history and per-project Claude config — not shared
|
||||||
|
.claude/
|
||||||
194
CLAUDE.md
Normal file
194
CLAUDE.md
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
# CLAUDE.md
|
||||||
|
|
||||||
|
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||||
|
|
||||||
|
## Project Overview
|
||||||
|
|
||||||
|
This is a Discord bot that integrates with ComfyUI to generate AI images and videos. Users interact via Discord commands, which queue generation requests that execute on a ComfyUI server.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### File Structure
|
||||||
|
|
||||||
|
The codebase is organized into focused modules for maintainability:
|
||||||
|
|
||||||
|
```
|
||||||
|
the-third-rev/
|
||||||
|
├── config.py # Configuration and constants
|
||||||
|
├── job_queue.py # Job queue system (SerialJobQueue)
|
||||||
|
├── workflow_manager.py # Workflow manipulation logic
|
||||||
|
├── workflow_state.py # Runtime workflow state management
|
||||||
|
├── discord_utils.py # Discord helpers and decorators
|
||||||
|
├── commands/ # Command handlers (organized by functionality)
|
||||||
|
│ ├── __init__.py # Command registration
|
||||||
|
│ ├── generation.py # generate, workflow-gen commands
|
||||||
|
│ ├── workflow.py # workflow-load command
|
||||||
|
│ ├── upload.py # upload command
|
||||||
|
│ ├── history.py # history, get-history commands
|
||||||
|
│ └── workflow_changes.py # get/set workflow changes commands
|
||||||
|
├── bot.py # Main bot entry point (~150 lines)
|
||||||
|
└── comfy_client.py # ComfyUI API client (~650 lines)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Core Components
|
||||||
|
|
||||||
|
- **bot.py**: Minimal Discord bot entry point. Loads configuration, creates dependencies, and registers commands. No command logic here.
|
||||||
|
- **comfy_client.py**: Async client wrapping ComfyUI's REST and WebSocket APIs. Dependencies (WorkflowManager, WorkflowStateManager) are injected via constructor.
|
||||||
|
- **config.py**: Centralized configuration with `BotConfig.from_env()` for loading environment variables and constants.
|
||||||
|
- **job_queue.py**: `SerialJobQueue` ensuring generation requests execute sequentially, preventing ComfyUI server overload.
|
||||||
|
- **workflow_manager.py**: `WorkflowManager` class handling workflow template storage and node manipulation (finding/replacing prompts, seeds, etc).
|
||||||
|
- **workflow_state.py**: `WorkflowStateManager` class managing runtime workflow changes (prompt, negative_prompt, input_image) in memory with optional file persistence.
|
||||||
|
- **discord_utils.py**: Reusable Discord utilities including `@require_comfy_client` decorator, argument parsing, and the `UploadView` component.
|
||||||
|
- **commands/**: Command handlers organized by functionality. Each module exports a `setup_*_commands(bot, config)` function.
|
||||||
|
|
||||||
|
### Key Architectural Patterns
|
||||||
|
|
||||||
|
1. **Dependency Injection**: ComfyClient receives WorkflowManager and WorkflowStateManager as constructor parameters, eliminating tight coupling to file-based state.
|
||||||
|
|
||||||
|
2. **Job Queue System**: All generation requests are queued through `SerialJobQueue` in job_queue.py. Jobs execute serially with a worker loop that catches and logs exceptions without crashing the bot.
|
||||||
|
|
||||||
|
3. **Workflow System**: The bot uses two modes:
|
||||||
|
- **Prompt mode**: Simple prompt + negative_prompt (requires workflow template with KSampler node)
|
||||||
|
- **Workflow mode**: Full workflow JSON with dynamic modifications from WorkflowStateManager
|
||||||
|
|
||||||
|
4. **Workflow Modification Flow**:
|
||||||
|
- Load workflow template via `bot.comfy.set_workflow()` or `bot.comfy.load_workflow_from_file()`
|
||||||
|
- Runtime changes (prompt, negative_prompt, input_image) stored in WorkflowStateManager
|
||||||
|
- At generation time, WorkflowManager methods locate nodes by class_type and title metadata, then inject values
|
||||||
|
- Seeds are randomized automatically via `workflow_manager.find_and_replace_seed()`
|
||||||
|
|
||||||
|
5. **Command Registration**: Commands are registered via `commands.register_all_commands(bot, config)` which calls individual `setup_*_commands()` functions from each command module.
|
||||||
|
|
||||||
|
6. **Configuration Management**: All configuration loaded via `BotConfig.from_env()` in config.py. Constants (command prefixes, error messages, limits) centralized in config.py.
|
||||||
|
|
||||||
|
7. **History Management**: ComfyClient maintains a bounded deque of recent generations (configurable via `history_limit`) for retrieval via `ttr!get-history`.
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
Required in `.env`:
|
||||||
|
- `DISCORD_BOT_TOKEN`: Discord bot authentication token
|
||||||
|
- `COMFY_SERVER`: ComfyUI server address (e.g., `localhost:8188` or `example.com:8188`)
|
||||||
|
|
||||||
|
Optional:
|
||||||
|
- `WORKFLOW_FILE`: Path to JSON workflow file to load at startup
|
||||||
|
- `COMFY_HISTORY_LIMIT`: Number of generations to keep in history (default: 10)
|
||||||
|
- `COMFY_OUTPUT_PATH`: Path to ComfyUI output directory (default: `C:\Users\ktrangia\Documents\ComfyUI\output`)
|
||||||
|
|
||||||
|
## Running the Bot
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python bot.py
|
||||||
|
```
|
||||||
|
|
||||||
|
The bot will:
|
||||||
|
1. Load configuration from environment variables via `BotConfig.from_env()`
|
||||||
|
2. Create WorkflowStateManager and WorkflowManager instances
|
||||||
|
3. Initialize ComfyClient with injected dependencies
|
||||||
|
4. Load workflow from `WORKFLOW_FILE` if specified
|
||||||
|
5. Register all commands via `commands.register_all_commands()`
|
||||||
|
6. Start Discord bot and job queue
|
||||||
|
7. Listen for commands with prefix `ttr!`
|
||||||
|
|
||||||
|
## Development Commands
|
||||||
|
|
||||||
|
No build/test/lint commands exist. This is a standalone Python application.
|
||||||
|
|
||||||
|
To run: `python bot.py`
|
||||||
|
|
||||||
|
## Key Implementation Details
|
||||||
|
|
||||||
|
### ComfyUI Workflow Node Injection
|
||||||
|
|
||||||
|
When generating with workflows, WorkflowManager searches for specific node patterns:
|
||||||
|
- **Prompt**: Finds `CLIPTextEncode` nodes with `_meta.title` containing "Positive Prompt"
|
||||||
|
- **Negative Prompt**: Finds `CLIPTextEncode` nodes with `_meta.title` containing "Negative Prompt"
|
||||||
|
- **Input Image**: Finds `LoadImage` nodes and replaces the `image` input
|
||||||
|
- **Seeds**: Finds any node with `inputs.seed` or `inputs.noise_seed` and randomizes
|
||||||
|
|
||||||
|
This pattern-matching approach means workflows must follow naming conventions in their node titles for dynamic updates to work.
|
||||||
|
|
||||||
|
### Discord Command Pattern
|
||||||
|
|
||||||
|
Commands use a labelled parameter syntax: `ttr!generate prompt:<text> negative_prompt:<text>`
|
||||||
|
|
||||||
|
Parsing is handled by helpers in discord_utils.py (e.g., `parse_labeled_args()`). The bot splits on keyword markers (`prompt:`, `negative_prompt:`, `type:`, etc.) rather than traditional argparse. Case is preserved for prompts.
|
||||||
|
|
||||||
|
### Job Queue Mechanics
|
||||||
|
|
||||||
|
Jobs are dataclasses with `run: Callable[[], Awaitable[None]]` and a `label` for logging. The queue returns position on submit. Jobs capture their context (ctx, prompts) via lambda closures when submitted.
|
||||||
|
|
||||||
|
### Image/Video Output Handling
|
||||||
|
|
||||||
|
The `_general_generate` method in ComfyClient returns both images and videos. Videos are identified by file extension (mp4, webm, avi) in the history response. For videos, the bot reads the file from disk at the path specified by `COMFY_OUTPUT_PATH` rather than downloading via the API.
|
||||||
|
|
||||||
|
### Command Validation
|
||||||
|
|
||||||
|
The `@require_comfy_client` decorator (from discord_utils.py) validates that `bot.comfy` exists before executing commands. This eliminates repetitive validation code in every command handler.
|
||||||
|
|
||||||
|
### State Management
|
||||||
|
|
||||||
|
WorkflowStateManager maintains runtime workflow changes in memory with optional persistence to `current-workflow-changes.json`. The file is loaded on initialization if it exists, and saved automatically when changes are made.
|
||||||
|
|
||||||
|
## Configuration System
|
||||||
|
|
||||||
|
Configuration is managed via the `BotConfig` dataclass in config.py:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from config import BotConfig
|
||||||
|
|
||||||
|
# Load from environment
|
||||||
|
config = BotConfig.from_env()
|
||||||
|
|
||||||
|
# Access configuration
|
||||||
|
server = config.comfy_server
|
||||||
|
history_limit = config.comfy_history_limit
|
||||||
|
output_path = config.comfy_output_path
|
||||||
|
```
|
||||||
|
|
||||||
|
All constants (command prefixes, error messages, defaults) are defined in config.py and imported where needed.
|
||||||
|
|
||||||
|
## Adding New Commands
|
||||||
|
|
||||||
|
To add a new command:
|
||||||
|
|
||||||
|
1. Create a new module in `commands/` (e.g., `commands/my_feature.py`)
|
||||||
|
2. Define a `setup_my_feature_commands(bot, config=None)` function
|
||||||
|
3. Use `@bot.command(name="...")` decorators to define commands
|
||||||
|
4. Use `@require_comfy_client` decorator if command needs ComfyClient
|
||||||
|
5. Import and call your setup function in `commands/__init__.py`'s `register_all_commands()`
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# commands/my_feature.py
|
||||||
|
from discord.ext import commands
|
||||||
|
from discord_utils import require_comfy_client
|
||||||
|
|
||||||
|
def setup_my_feature_commands(bot):
|
||||||
|
@bot.command(name="my-command")
|
||||||
|
@require_comfy_client
|
||||||
|
async def my_command(ctx: commands.Context):
|
||||||
|
await ctx.reply("Hello from my command!")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
From imports:
|
||||||
|
- discord.py
|
||||||
|
- aiohttp
|
||||||
|
- websockets
|
||||||
|
- python-dotenv (optional, for .env loading)
|
||||||
|
|
||||||
|
No requirements.txt exists. Install manually: `pip install discord.py aiohttp websockets python-dotenv`
|
||||||
|
|
||||||
|
## Code Organization Principles
|
||||||
|
|
||||||
|
The refactored codebase follows these principles:
|
||||||
|
|
||||||
|
1. **Single Responsibility**: Each module has one clear purpose
|
||||||
|
2. **Dependency Injection**: Dependencies passed via constructor, not created internally
|
||||||
|
3. **Configuration Centralization**: All configuration in config.py
|
||||||
|
4. **Command Separation**: Commands grouped by functionality in separate modules
|
||||||
|
5. **No Magic Strings**: Constants defined once in config.py
|
||||||
|
6. **Type Safety**: Modern Python type hints throughout (dict[str, Any] instead of Dict)
|
||||||
|
7. **Logging**: Using logger methods instead of print() statements
|
||||||
780
DEVELOPMENT.md
Normal file
780
DEVELOPMENT.md
Normal file
@@ -0,0 +1,780 @@
|
|||||||
|
# Development Guide
|
||||||
|
|
||||||
|
This guide explains how to add new commands, features, and modules to the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Adding a New Command](#adding-a-new-command)
|
||||||
|
- [Adding a New Feature Module](#adding-a-new-feature-module)
|
||||||
|
- [Adding Configuration Options](#adding-configuration-options)
|
||||||
|
- [Working with Workflows](#working-with-workflows)
|
||||||
|
- [Best Practices](#best-practices)
|
||||||
|
- [Common Patterns](#common-patterns)
|
||||||
|
- [Testing Your Changes](#testing-your-changes)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Adding a New Command
|
||||||
|
|
||||||
|
Commands are organized in the `commands/` directory by functionality. Here's how to add a new command:
|
||||||
|
|
||||||
|
### Step 1: Choose the Right Module
|
||||||
|
|
||||||
|
Determine which existing command module your command belongs to:
|
||||||
|
|
||||||
|
- **generation.py** - Image/video generation commands
|
||||||
|
- **workflow.py** - Workflow template management
|
||||||
|
- **upload.py** - File upload commands
|
||||||
|
- **history.py** - History viewing and retrieval
|
||||||
|
- **workflow_changes.py** - Runtime workflow parameter management
|
||||||
|
|
||||||
|
If none fit, create a new module (see [Adding a New Feature Module](#adding-a-new-feature-module)).
|
||||||
|
|
||||||
|
### Step 2: Add Your Command Function
|
||||||
|
|
||||||
|
Edit the appropriate module in `commands/` and add your command to the `setup_*_commands()` function:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# commands/generation.py
|
||||||
|
|
||||||
|
def setup_generation_commands(bot, config):
|
||||||
|
# ... existing commands ...
|
||||||
|
|
||||||
|
@bot.command(name="my-new-command", aliases=["mnc", "my-cmd"])
|
||||||
|
@require_comfy_client # Use this decorator if you need bot.comfy
|
||||||
|
async def my_new_command(ctx: commands.Context, *, args: str = "") -> None:
|
||||||
|
"""
|
||||||
|
Brief description of what your command does.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ttr!my-new-command [arguments]
|
||||||
|
|
||||||
|
Longer description with examples and details.
|
||||||
|
"""
|
||||||
|
# Parse arguments if needed
|
||||||
|
if not args:
|
||||||
|
await ctx.reply("Please provide arguments!", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Your command logic here
|
||||||
|
result = await bot.comfy.some_method(args)
|
||||||
|
|
||||||
|
# Send response
|
||||||
|
await ctx.reply(f"Success! Result: {result}", mention_author=False)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to execute my-new-command")
|
||||||
|
await ctx.reply(
|
||||||
|
f"An error occurred: {type(exc).__name__}: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 3: Import Required Dependencies
|
||||||
|
|
||||||
|
At the top of your command module, import what you need:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import logging
|
||||||
|
from discord.ext import commands
|
||||||
|
from discord_utils import require_comfy_client, parse_labeled_args
|
||||||
|
from config import ARG_PROMPT_KEY, ARG_TYPE_KEY
|
||||||
|
from job_queue import Job
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 4: Test Your Command
|
||||||
|
|
||||||
|
Run the bot and test your command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python bot.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Then in Discord:
|
||||||
|
```
|
||||||
|
ttr!my-new-command test arguments
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Adding a New Feature Module
|
||||||
|
|
||||||
|
If your commands don't fit existing modules, create a new one:
|
||||||
|
|
||||||
|
### Step 1: Create the Module File
|
||||||
|
|
||||||
|
Create `commands/your_feature.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""
|
||||||
|
commands/your_feature.py
|
||||||
|
========================
|
||||||
|
|
||||||
|
Description of what this module handles.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from discord.ext import commands
|
||||||
|
from discord_utils import require_comfy_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_your_feature_commands(bot, config):
|
||||||
|
"""
|
||||||
|
Register your feature commands with the bot.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
bot : commands.Bot
|
||||||
|
The Discord bot instance.
|
||||||
|
config : BotConfig
|
||||||
|
The bot configuration object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@bot.command(name="feature-command")
|
||||||
|
@require_comfy_client
|
||||||
|
async def feature_command(ctx: commands.Context, *, args: str = "") -> None:
|
||||||
|
"""Command description."""
|
||||||
|
await ctx.reply("Feature command executed!", mention_author=False)
|
||||||
|
|
||||||
|
@bot.command(name="another-command", aliases=["ac"])
|
||||||
|
async def another_command(ctx: commands.Context) -> None:
|
||||||
|
"""Another command description."""
|
||||||
|
await ctx.reply("Another command!", mention_author=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Register in commands/__init__.py
|
||||||
|
|
||||||
|
Edit `commands/__init__.py` to import and register your module:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .generation import setup_generation_commands
|
||||||
|
from .workflow import setup_workflow_commands
|
||||||
|
from .upload import setup_upload_commands
|
||||||
|
from .history import setup_history_commands
|
||||||
|
from .workflow_changes import setup_workflow_changes_commands
|
||||||
|
from .your_feature import setup_your_feature_commands # ADD THIS
|
||||||
|
|
||||||
|
|
||||||
|
def register_all_commands(bot, config):
|
||||||
|
"""Register all bot commands."""
|
||||||
|
setup_generation_commands(bot, config)
|
||||||
|
setup_workflow_commands(bot)
|
||||||
|
setup_upload_commands(bot)
|
||||||
|
setup_history_commands(bot)
|
||||||
|
setup_workflow_changes_commands(bot)
|
||||||
|
setup_your_feature_commands(bot, config) # ADD THIS
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 3: Update Documentation
|
||||||
|
|
||||||
|
Add your module to `CLAUDE.md`:
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
### File Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
commands/
|
||||||
|
├── __init__.py
|
||||||
|
├── generation.py
|
||||||
|
├── workflow.py
|
||||||
|
├── upload.py
|
||||||
|
├── history.py
|
||||||
|
├── workflow_changes.py
|
||||||
|
└── your_feature.py # Your new module
|
||||||
|
```
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Adding Configuration Options
|
||||||
|
|
||||||
|
Configuration is centralized in `config.py`. Here's how to add new options:
|
||||||
|
|
||||||
|
### Step 1: Add Constants (if needed)
|
||||||
|
|
||||||
|
Edit `config.py` and add constants in the appropriate section:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ========================================
|
||||||
|
# Your Feature Constants
|
||||||
|
# ========================================
|
||||||
|
|
||||||
|
MY_FEATURE_DEFAULT_VALUE = 42
|
||||||
|
"""Default value for my feature."""
|
||||||
|
|
||||||
|
MY_FEATURE_MAX_LIMIT = 100
|
||||||
|
"""Maximum limit for my feature."""
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Add to BotConfig (if environment variable)
|
||||||
|
|
||||||
|
If your config comes from environment variables, add it to `BotConfig`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class BotConfig:
|
||||||
|
"""Configuration container for the Discord ComfyUI bot."""
|
||||||
|
|
||||||
|
discord_bot_token: str
|
||||||
|
comfy_server: str
|
||||||
|
comfy_output_path: str
|
||||||
|
comfy_history_limit: int
|
||||||
|
workflow_file: Optional[str] = None
|
||||||
|
my_feature_enabled: bool = False # ADD THIS
|
||||||
|
my_feature_value: int = MY_FEATURE_DEFAULT_VALUE # ADD THIS
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 3: Load in from_env()
|
||||||
|
|
||||||
|
Add loading logic in `BotConfig.from_env()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls) -> BotConfig:
|
||||||
|
"""Create a BotConfig instance by loading from environment."""
|
||||||
|
# ... existing code ...
|
||||||
|
|
||||||
|
# Load your feature config
|
||||||
|
my_feature_enabled = os.getenv("MY_FEATURE_ENABLED", "false").lower() == "true"
|
||||||
|
|
||||||
|
try:
|
||||||
|
my_feature_value = int(os.getenv("MY_FEATURE_VALUE", str(MY_FEATURE_DEFAULT_VALUE)))
|
||||||
|
except ValueError:
|
||||||
|
my_feature_value = MY_FEATURE_DEFAULT_VALUE
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
# ... existing parameters ...
|
||||||
|
my_feature_enabled=my_feature_enabled,
|
||||||
|
my_feature_value=my_feature_value,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 4: Use in Your Commands
|
||||||
|
|
||||||
|
Access config in your commands:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def setup_your_feature_commands(bot, config):
|
||||||
|
@bot.command(name="feature")
|
||||||
|
async def feature_command(ctx: commands.Context):
|
||||||
|
if not config.my_feature_enabled:
|
||||||
|
await ctx.reply("Feature is disabled!", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
value = config.my_feature_value
|
||||||
|
await ctx.reply(f"Feature value: {value}", mention_author=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 5: Document the Environment Variable
|
||||||
|
|
||||||
|
Update `CLAUDE.md` and add to `.env.example` (if you create one):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Feature Configuration
|
||||||
|
MY_FEATURE_ENABLED=true
|
||||||
|
MY_FEATURE_VALUE=42
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Working with Workflows
|
||||||
|
|
||||||
|
The bot has separate concerns for workflows:
|
||||||
|
|
||||||
|
- **WorkflowManager** (`workflow_manager.py`) - Template storage and node manipulation
|
||||||
|
- **WorkflowStateManager** (`workflow_state.py`) - Runtime state (prompt, negative_prompt, input_image)
|
||||||
|
- **ComfyClient** (`comfy_client.py`) - Uses both managers to generate images
|
||||||
|
|
||||||
|
### Adding New Workflow Node Types
|
||||||
|
|
||||||
|
If you need to manipulate new types of nodes in workflows:
|
||||||
|
|
||||||
|
#### Step 1: Add Method to WorkflowManager
|
||||||
|
|
||||||
|
Edit `workflow_manager.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def find_and_replace_my_node(
|
||||||
|
self, workflow: Dict[str, Any], my_value: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Find and replace my custom node type.
|
||||||
|
|
||||||
|
This searches for nodes of a specific class_type and updates their inputs.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
workflow : Dict[str, Any]
|
||||||
|
The workflow definition to modify.
|
||||||
|
my_value : str
|
||||||
|
The value to inject.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Dict[str, Any]
|
||||||
|
The modified workflow.
|
||||||
|
"""
|
||||||
|
for node_id, node in workflow.items():
|
||||||
|
if node.get("class_type") == "MyCustomNodeType" and node.get("inputs"):
|
||||||
|
# Check metadata for specific node identification
|
||||||
|
meta = node.get("_meta", {})
|
||||||
|
if "My Custom Node" in meta.get("title", ""):
|
||||||
|
workflow[node_id]["inputs"]["my_input"] = my_value
|
||||||
|
logger.debug("Replaced my_value in node %s", node_id)
|
||||||
|
|
||||||
|
return workflow
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 2: Add to apply_state_changes()
|
||||||
|
|
||||||
|
Update `apply_state_changes()` to include your new manipulation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def apply_state_changes(
|
||||||
|
self,
|
||||||
|
workflow: Dict[str, Any],
|
||||||
|
prompt: Optional[str] = None,
|
||||||
|
negative_prompt: Optional[str] = None,
|
||||||
|
input_image: Optional[str] = None,
|
||||||
|
my_custom_value: Optional[str] = None, # ADD THIS
|
||||||
|
randomize_seed: bool = True,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Apply multiple state changes to a workflow in one pass."""
|
||||||
|
if randomize_seed:
|
||||||
|
workflow = self.find_and_replace_seed(workflow)
|
||||||
|
|
||||||
|
if prompt is not None:
|
||||||
|
workflow = self.find_and_replace_prompt(workflow, prompt)
|
||||||
|
|
||||||
|
if negative_prompt is not None:
|
||||||
|
workflow = self.find_and_replace_negative_prompt(workflow, negative_prompt)
|
||||||
|
|
||||||
|
if input_image is not None:
|
||||||
|
workflow = self.find_and_replace_input_image(workflow, input_image)
|
||||||
|
|
||||||
|
# ADD THIS
|
||||||
|
if my_custom_value is not None:
|
||||||
|
workflow = self.find_and_replace_my_node(workflow, my_custom_value)
|
||||||
|
|
||||||
|
return workflow
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 3: Add State to WorkflowStateManager
|
||||||
|
|
||||||
|
Edit `workflow_state.py` to track the new state:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def __init__(self, state_file: Optional[str] = None):
|
||||||
|
"""Initialize the workflow state manager."""
|
||||||
|
self._state: Dict[str, Any] = {
|
||||||
|
"prompt": None,
|
||||||
|
"negative_prompt": None,
|
||||||
|
"input_image": None,
|
||||||
|
"my_custom_value": None, # ADD THIS
|
||||||
|
}
|
||||||
|
# ... rest of init ...
|
||||||
|
|
||||||
|
def set_my_custom_value(self, value: str) -> None:
|
||||||
|
"""Set the custom value."""
|
||||||
|
self._state["my_custom_value"] = value
|
||||||
|
if self._state_file:
|
||||||
|
try:
|
||||||
|
self.save_to_file()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_my_custom_value(self) -> Optional[str]:
|
||||||
|
"""Get the custom value."""
|
||||||
|
return self._state.get("my_custom_value")
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 4: Use in ComfyClient
|
||||||
|
|
||||||
|
The ComfyClient will automatically use your new state if you update `generate_image_with_workflow()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def generate_image_with_workflow(self) -> tuple[List[bytes], List[dict[str, Any]], str]:
|
||||||
|
# ... existing code ...
|
||||||
|
|
||||||
|
# Get current state changes
|
||||||
|
changes = self.state_manager.get_changes()
|
||||||
|
prompt = changes.get("prompt")
|
||||||
|
negative_prompt = changes.get("negative_prompt")
|
||||||
|
input_image = changes.get("input_image")
|
||||||
|
my_custom_value = changes.get("my_custom_value") # ADD THIS
|
||||||
|
|
||||||
|
# Apply changes using WorkflowManager
|
||||||
|
workflow = self.workflow_manager.apply_state_changes(
|
||||||
|
workflow,
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
input_image=input_image,
|
||||||
|
my_custom_value=my_custom_value, # ADD THIS
|
||||||
|
randomize_seed=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ... rest of method ...
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### Command Design
|
||||||
|
|
||||||
|
1. **Use descriptive names**: `ttr!generate` is better than `ttr!gen` (but provide aliases)
|
||||||
|
2. **Validate inputs early**: Check arguments before starting long operations
|
||||||
|
3. **Provide clear feedback**: Tell users what's happening and when it's done
|
||||||
|
4. **Handle errors gracefully**: Catch exceptions and show user-friendly messages
|
||||||
|
5. **Use decorators**: `@require_comfy_client` eliminates boilerplate
|
||||||
|
|
||||||
|
### Code Organization
|
||||||
|
|
||||||
|
1. **One responsibility per module**: Don't mix unrelated commands
|
||||||
|
2. **Keep functions small**: If a function is > 50 lines, consider splitting it
|
||||||
|
3. **Use type hints**: Help future developers understand your code
|
||||||
|
4. **Document with docstrings**: Explain what, why, and how
|
||||||
|
|
||||||
|
### Discord Best Practices
|
||||||
|
|
||||||
|
1. **Use `mention_author=False`**: Prevents spam from @mentions
|
||||||
|
2. **Use `delete_after=X`**: For temporary status messages
|
||||||
|
3. **Use `ephemeral=True`**: For interaction responses (buttons/modals)
|
||||||
|
4. **Limit file attachments**: Discord has a 4-file limit (use `MAX_IMAGES_PER_RESPONSE`)
|
||||||
|
|
||||||
|
### Performance
|
||||||
|
|
||||||
|
1. **Use job queue for long operations**: Queue generation requests
|
||||||
|
2. **Use typing indicator**: `async with ctx.typing():` shows bot is working
|
||||||
|
3. **Batch operations**: Don't send 10 separate messages when 1 will do
|
||||||
|
4. **Close resources**: Always close aiohttp sessions, file handles
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Common Patterns
|
||||||
|
|
||||||
|
### Pattern 1: Labeled Argument Parsing
|
||||||
|
|
||||||
|
For commands with `key:value` syntax:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from discord_utils import parse_labeled_args
|
||||||
|
from config import ARG_PROMPT_KEY, ARG_TYPE_KEY
|
||||||
|
|
||||||
|
@bot.command(name="my-command")
|
||||||
|
async def my_command(ctx: commands.Context, *, args: str = ""):
|
||||||
|
# Parse labeled arguments
|
||||||
|
parsed = parse_labeled_args(args, [ARG_PROMPT_KEY, ARG_TYPE_KEY])
|
||||||
|
|
||||||
|
prompt = parsed.get("prompt") # None if not provided
|
||||||
|
image_type = parsed.get("type") or "input" # Default to "input"
|
||||||
|
|
||||||
|
if not prompt:
|
||||||
|
await ctx.reply("Please provide a prompt!", mention_author=False)
|
||||||
|
return
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern 2: Queued Job Execution
|
||||||
|
|
||||||
|
For long-running operations:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from job_queue import Job
|
||||||
|
|
||||||
|
@bot.command(name="long-operation")
|
||||||
|
@require_comfy_client
|
||||||
|
async def long_operation(ctx: commands.Context, *, args: str = ""):
|
||||||
|
try:
|
||||||
|
# Define the job function
|
||||||
|
async def _run_job():
|
||||||
|
async with ctx.typing():
|
||||||
|
result = await bot.comfy.some_long_operation(args)
|
||||||
|
await ctx.reply(f"Done! Result: {result}", mention_author=False)
|
||||||
|
|
||||||
|
# Submit to queue
|
||||||
|
position = await bot.jobq.submit(
|
||||||
|
Job(
|
||||||
|
label=f"long-operation:{ctx.author.id}",
|
||||||
|
run=_run_job,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await ctx.reply(
|
||||||
|
f"Queued ✅ (position: {position})",
|
||||||
|
mention_author=False,
|
||||||
|
delete_after=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to queue long operation")
|
||||||
|
await ctx.reply(
|
||||||
|
f"An error occurred: {type(exc).__name__}: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern 3: File Attachments
|
||||||
|
|
||||||
|
For uploading files to ComfyUI:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@bot.command(name="upload-and-process")
|
||||||
|
@require_comfy_client
|
||||||
|
async def upload_and_process(ctx: commands.Context):
|
||||||
|
if not ctx.message.attachments:
|
||||||
|
await ctx.reply("Please attach a file!", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
for attachment in ctx.message.attachments:
|
||||||
|
try:
|
||||||
|
# Download attachment
|
||||||
|
data = await attachment.read()
|
||||||
|
|
||||||
|
# Upload to ComfyUI
|
||||||
|
result = await bot.comfy.upload_image(
|
||||||
|
data,
|
||||||
|
attachment.filename,
|
||||||
|
image_type="input",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process the uploaded file
|
||||||
|
filename = result.get("name")
|
||||||
|
await ctx.reply(f"Uploaded: {filename}", mention_author=False)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to process attachment")
|
||||||
|
await ctx.reply(
|
||||||
|
f"Failed: {attachment.filename}: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern 4: Interactive UI (Buttons)
|
||||||
|
|
||||||
|
For adding buttons to messages:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from discord.ui import View, Button
|
||||||
|
import discord
|
||||||
|
|
||||||
|
class MyView(View):
|
||||||
|
def __init__(self, data: str):
|
||||||
|
super().__init__(timeout=None)
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
@discord.ui.button(label="Click Me", style=discord.ButtonStyle.primary)
|
||||||
|
async def button_callback(
|
||||||
|
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||||
|
):
|
||||||
|
# Handle button click
|
||||||
|
await interaction.response.send_message(
|
||||||
|
f"You clicked! Data: {self.data}",
|
||||||
|
ephemeral=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@bot.command(name="interactive")
|
||||||
|
async def interactive(ctx: commands.Context):
|
||||||
|
view = MyView(data="example")
|
||||||
|
await ctx.reply("Click the button:", view=view, mention_author=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern 5: Using Configuration
|
||||||
|
|
||||||
|
Access bot configuration in commands:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def setup_my_commands(bot, config):
|
||||||
|
@bot.command(name="check-config")
|
||||||
|
async def check_config(ctx: commands.Context):
|
||||||
|
# Access config values
|
||||||
|
server = config.comfy_server
|
||||||
|
output_path = config.comfy_output_path
|
||||||
|
|
||||||
|
await ctx.reply(
|
||||||
|
f"Server: {server}\nOutput: {output_path}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Testing Your Changes
|
||||||
|
|
||||||
|
### Manual Testing Checklist
|
||||||
|
|
||||||
|
1. **Start the bot**: `python bot.py`
|
||||||
|
- Verify no import errors
|
||||||
|
- Check configuration loads correctly
|
||||||
|
- Confirm all commands register
|
||||||
|
|
||||||
|
2. **Test basic functionality**:
|
||||||
|
```
|
||||||
|
ttr!test
|
||||||
|
ttr!help
|
||||||
|
ttr!your-new-command
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Test error handling**:
|
||||||
|
- Run command with missing arguments
|
||||||
|
- Run command with invalid arguments
|
||||||
|
- Test when ComfyUI is unavailable (if applicable)
|
||||||
|
|
||||||
|
4. **Test edge cases**:
|
||||||
|
- Very long inputs
|
||||||
|
- Special characters
|
||||||
|
- Concurrent command execution
|
||||||
|
- Commands while queue is full
|
||||||
|
|
||||||
|
### Syntax Validation
|
||||||
|
|
||||||
|
Check for syntax errors without running:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m py_compile bot.py
|
||||||
|
python -m py_compile commands/your_feature.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Check Imports
|
||||||
|
|
||||||
|
Verify all imports work:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "from commands.your_feature import setup_your_feature_commands; print('OK')"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Code Style
|
||||||
|
|
||||||
|
Follow these conventions:
|
||||||
|
|
||||||
|
- **Indentation**: 4 spaces (no tabs)
|
||||||
|
- **Line length**: Max 100 characters (documentation can be longer)
|
||||||
|
- **Docstrings**: Use Google style or NumPy style (match existing code)
|
||||||
|
- **Imports**: Group stdlib, third-party, local (separated by blank lines)
|
||||||
|
- **Type hints**: Use modern syntax (`dict[str, Any]` not `Dict[str, Any]`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Example: Complete Feature Addition
|
||||||
|
|
||||||
|
Here's a complete example adding a "status" command:
|
||||||
|
|
||||||
|
### 1. Create commands/status.py
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""
|
||||||
|
commands/status.py
|
||||||
|
==================
|
||||||
|
|
||||||
|
Bot status and diagnostics commands.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from discord.ext import commands
|
||||||
|
from discord_utils import require_comfy_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_status_commands(bot, config):
|
||||||
|
"""Register status commands with the bot."""
|
||||||
|
|
||||||
|
@bot.command(name="status", aliases=["s", "stat"])
|
||||||
|
async def status_command(ctx: commands.Context) -> None:
|
||||||
|
"""
|
||||||
|
Show bot status and queue information.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ttr!status
|
||||||
|
|
||||||
|
Displays:
|
||||||
|
- Bot connection status
|
||||||
|
- ComfyUI connection status
|
||||||
|
- Current queue size
|
||||||
|
- Configuration info
|
||||||
|
"""
|
||||||
|
# Check bot status
|
||||||
|
latency_ms = round(bot.latency * 1000)
|
||||||
|
|
||||||
|
# Check queue
|
||||||
|
if hasattr(bot, "jobq"):
|
||||||
|
queue_size = await bot.jobq.get_queue_size()
|
||||||
|
else:
|
||||||
|
queue_size = 0
|
||||||
|
|
||||||
|
# Check ComfyUI
|
||||||
|
comfy_status = "✅ Connected" if hasattr(bot, "comfy") else "❌ Not configured"
|
||||||
|
|
||||||
|
# Build status message
|
||||||
|
status_msg = [
|
||||||
|
"**Bot Status**",
|
||||||
|
f"• Latency: {latency_ms}ms",
|
||||||
|
f"• Queue size: {queue_size}",
|
||||||
|
f"• ComfyUI: {comfy_status}",
|
||||||
|
f"• Server: {config.comfy_server}",
|
||||||
|
]
|
||||||
|
|
||||||
|
await ctx.reply("\n".join(status_msg), mention_author=False)
|
||||||
|
|
||||||
|
@bot.command(name="ping")
|
||||||
|
async def ping_command(ctx: commands.Context) -> None:
|
||||||
|
"""
|
||||||
|
Check bot responsiveness.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ttr!ping
|
||||||
|
"""
|
||||||
|
latency_ms = round(bot.latency * 1000)
|
||||||
|
await ctx.reply(f"🏓 Pong! Latency: {latency_ms}ms", mention_author=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Register in commands/__init__.py
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .status import setup_status_commands
|
||||||
|
|
||||||
|
def register_all_commands(bot, config):
|
||||||
|
# ... existing registrations ...
|
||||||
|
setup_status_commands(bot, config)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Update CLAUDE.md
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
- **commands/status.py** - Bot status and diagnostics commands
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Test
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python bot.py
|
||||||
|
```
|
||||||
|
|
||||||
|
In Discord:
|
||||||
|
```
|
||||||
|
ttr!status
|
||||||
|
ttr!ping
|
||||||
|
ttr!s (alias test)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Getting Help
|
||||||
|
|
||||||
|
If you're stuck:
|
||||||
|
|
||||||
|
1. **Check CLAUDE.md**: Architecture and patterns documented there
|
||||||
|
2. **Read existing commands**: See how similar features are implemented
|
||||||
|
3. **Check logs**: Run bot and check console output for errors
|
||||||
|
4. **Test incrementally**: Add small pieces and test frequently
|
||||||
|
|
||||||
|
Remember: The refactored architecture makes adding features straightforward. Follow the patterns, and your code will fit right in! 🚀
|
||||||
196
QUICK_START.md
Normal file
196
QUICK_START.md
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
# Quick Start Guide
|
||||||
|
|
||||||
|
Quick reference for common development tasks.
|
||||||
|
|
||||||
|
## Add a Simple Command
|
||||||
|
|
||||||
|
```python
|
||||||
|
# commands/your_module.py
|
||||||
|
|
||||||
|
def setup_your_commands(bot, config):
|
||||||
|
@bot.command(name="hello")
|
||||||
|
async def hello(ctx):
|
||||||
|
await ctx.reply("Hello!", mention_author=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
Register it:
|
||||||
|
```python
|
||||||
|
# commands/__init__.py
|
||||||
|
from .your_module import setup_your_commands
|
||||||
|
|
||||||
|
def register_all_commands(bot, config):
|
||||||
|
# ... existing ...
|
||||||
|
setup_your_commands(bot, config)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Add a Command That Uses ComfyUI
|
||||||
|
|
||||||
|
```python
|
||||||
|
from discord_utils import require_comfy_client
|
||||||
|
|
||||||
|
@bot.command(name="my-cmd")
|
||||||
|
@require_comfy_client # Validates bot.comfy exists
|
||||||
|
async def my_cmd(ctx):
|
||||||
|
result = await bot.comfy.some_method()
|
||||||
|
await ctx.reply(f"Result: {result}", mention_author=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Add a Long-Running Command
|
||||||
|
|
||||||
|
```python
|
||||||
|
from job_queue import Job
|
||||||
|
|
||||||
|
@bot.command(name="generate")
|
||||||
|
@require_comfy_client
|
||||||
|
async def generate(ctx, *, args: str = ""):
|
||||||
|
async def _run():
|
||||||
|
async with ctx.typing():
|
||||||
|
result = await bot.comfy.generate_image(args)
|
||||||
|
await ctx.reply(f"Done! {result}", mention_author=False)
|
||||||
|
|
||||||
|
pos = await bot.jobq.submit(Job(label="generate", run=_run))
|
||||||
|
await ctx.reply(f"Queued ✅ (position: {pos})", mention_author=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Add Configuration
|
||||||
|
|
||||||
|
```python
|
||||||
|
# config.py
|
||||||
|
|
||||||
|
MY_FEATURE_ENABLED = True
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BotConfig:
|
||||||
|
# ... existing fields ...
|
||||||
|
my_feature_enabled: bool = MY_FEATURE_ENABLED
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls) -> BotConfig:
|
||||||
|
# ... existing code ...
|
||||||
|
my_feature = os.getenv("MY_FEATURE_ENABLED", "true").lower() == "true"
|
||||||
|
return cls(
|
||||||
|
# ... existing params ...
|
||||||
|
my_feature_enabled=my_feature
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Use in commands:
|
||||||
|
```python
|
||||||
|
def setup_my_commands(bot, config):
|
||||||
|
@bot.command(name="feature")
|
||||||
|
async def feature(ctx):
|
||||||
|
if config.my_feature_enabled:
|
||||||
|
await ctx.reply("Enabled!", mention_author=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Parse Command Arguments
|
||||||
|
|
||||||
|
```python
|
||||||
|
from discord_utils import parse_labeled_args
|
||||||
|
from config import ARG_PROMPT_KEY, ARG_TYPE_KEY
|
||||||
|
|
||||||
|
@bot.command(name="cmd")
|
||||||
|
async def cmd(ctx, *, args: str = ""):
|
||||||
|
# Parse "prompt:text type:value" format
|
||||||
|
parsed = parse_labeled_args(args, [ARG_PROMPT_KEY, ARG_TYPE_KEY])
|
||||||
|
|
||||||
|
prompt = parsed.get("prompt")
|
||||||
|
img_type = parsed.get("type") or "input" # Default
|
||||||
|
|
||||||
|
if not prompt:
|
||||||
|
await ctx.reply("Missing prompt!", mention_author=False)
|
||||||
|
return
|
||||||
|
```
|
||||||
|
|
||||||
|
## Handle File Uploads
|
||||||
|
|
||||||
|
```python
|
||||||
|
@bot.command(name="upload")
|
||||||
|
async def upload(ctx):
|
||||||
|
if not ctx.message.attachments:
|
||||||
|
await ctx.reply("Attach a file!", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
for attachment in ctx.message.attachments:
|
||||||
|
data = await attachment.read()
|
||||||
|
# Process data...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Access Bot State
|
||||||
|
|
||||||
|
```python
|
||||||
|
@bot.command(name="info")
|
||||||
|
async def info(ctx):
|
||||||
|
# Queue size
|
||||||
|
queue_size = await bot.jobq.get_queue_size()
|
||||||
|
|
||||||
|
# Config
|
||||||
|
server = bot.config.comfy_server
|
||||||
|
|
||||||
|
# Last generation
|
||||||
|
last_id = bot.comfy.last_prompt_id
|
||||||
|
|
||||||
|
await ctx.reply(
|
||||||
|
f"Queue: {queue_size}, Server: {server}, Last: {last_id}",
|
||||||
|
mention_author=False
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Add Buttons
|
||||||
|
|
||||||
|
```python
|
||||||
|
from discord.ui import View, Button
|
||||||
|
import discord
|
||||||
|
|
||||||
|
class MyView(View):
|
||||||
|
@discord.ui.button(label="Click", style=discord.ButtonStyle.primary)
|
||||||
|
async def button_callback(self, interaction, button):
|
||||||
|
await interaction.response.send_message("Clicked!", ephemeral=True)
|
||||||
|
|
||||||
|
@bot.command(name="interactive")
|
||||||
|
async def interactive(ctx):
|
||||||
|
await ctx.reply("Press button:", view=MyView(), mention_author=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common Imports
|
||||||
|
|
||||||
|
```python
|
||||||
|
from __future__ import annotations
|
||||||
|
import logging
|
||||||
|
from discord.ext import commands
|
||||||
|
from discord_utils import require_comfy_client
|
||||||
|
from config import ARG_PROMPT_KEY, ARG_TYPE_KEY
|
||||||
|
from job_queue import Job
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Test Your Changes
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Syntax check
|
||||||
|
python -m py_compile commands/your_module.py
|
||||||
|
|
||||||
|
# Run bot
|
||||||
|
python bot.py
|
||||||
|
|
||||||
|
# Test in Discord
|
||||||
|
ttr!your-command
|
||||||
|
```
|
||||||
|
|
||||||
|
## File Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
commands/
|
||||||
|
├── __init__.py # Register all commands here
|
||||||
|
├── generation.py # Generation commands
|
||||||
|
├── workflow.py # Workflow management
|
||||||
|
├── upload.py # File uploads
|
||||||
|
├── history.py # History retrieval
|
||||||
|
├── workflow_changes.py # State management
|
||||||
|
└── your_module.py # Your new module
|
||||||
|
```
|
||||||
|
|
||||||
|
## Need More Details?
|
||||||
|
|
||||||
|
See `DEVELOPMENT.md` for comprehensive guide with examples, patterns, and best practices.
|
||||||
345
README.md
Normal file
345
README.md
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
# Discord ComfyUI Bot
|
||||||
|
|
||||||
|
A Discord bot that integrates with ComfyUI to generate AI images and videos through Discord commands.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- 🎨 **Image Generation** - Generate images using simple prompts or complex workflows
|
||||||
|
- 🎬 **Video Generation** - Support for video output workflows
|
||||||
|
- 📝 **Workflow Management** - Load, modify, and execute ComfyUI workflows
|
||||||
|
- 📤 **Image Upload** - Upload reference images directly through Discord
|
||||||
|
- 📊 **Generation History** - Track and retrieve past generations
|
||||||
|
- ⚙️ **Runtime Workflow Modification** - Change prompts, negative prompts, and input images on the fly
|
||||||
|
- 🔄 **Job Queue System** - Sequential execution prevents server overload
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Python 3.9+
|
||||||
|
- Discord Bot Token ([create one here](https://discord.com/developers/applications))
|
||||||
|
- ComfyUI Server running and accessible
|
||||||
|
- Required packages: `discord.py`, `aiohttp`, `websockets`, `python-dotenv`
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
1. **Clone or download this repository**
|
||||||
|
|
||||||
|
2. **Install dependencies**:
|
||||||
|
```bash
|
||||||
|
pip install discord.py aiohttp websockets python-dotenv
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Create `.env` file** with your credentials:
|
||||||
|
```bash
|
||||||
|
DISCORD_BOT_TOKEN=your_discord_bot_token_here
|
||||||
|
COMFY_SERVER=localhost:8188
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Run the bot**:
|
||||||
|
```bash
|
||||||
|
python bot.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Create a `.env` file in the project root:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Required
|
||||||
|
DISCORD_BOT_TOKEN=your_discord_bot_token
|
||||||
|
COMFY_SERVER=localhost:8188
|
||||||
|
|
||||||
|
# Optional
|
||||||
|
WORKFLOW_FILE=wan2.2-fast.json
|
||||||
|
COMFY_HISTORY_LIMIT=10
|
||||||
|
COMFY_OUTPUT_PATH=C:\Users\YourName\Documents\ComfyUI\output
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration Options
|
||||||
|
|
||||||
|
| Variable | Required | Default | Description |
|
||||||
|
|----------|----------|---------|-------------|
|
||||||
|
| `DISCORD_BOT_TOKEN` | ✅ Yes | - | Discord bot authentication token |
|
||||||
|
| `COMFY_SERVER` | ✅ Yes | - | ComfyUI server address (host:port) |
|
||||||
|
| `WORKFLOW_FILE` | ❌ No | - | Path to workflow JSON to load at startup |
|
||||||
|
| `COMFY_HISTORY_LIMIT` | ❌ No | `10` | Number of generations to keep in history |
|
||||||
|
| `COMFY_OUTPUT_PATH` | ❌ No | `C:\Users\...\ComfyUI\output` | Path to ComfyUI output directory |
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
All commands use the `ttr!` prefix.
|
||||||
|
|
||||||
|
### Basic Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Test if bot is working
|
||||||
|
ttr!test
|
||||||
|
|
||||||
|
# Generate an image with a prompt
|
||||||
|
ttr!generate prompt:a beautiful sunset over mountains
|
||||||
|
|
||||||
|
# Generate with negative prompt
|
||||||
|
ttr!generate prompt:a cat negative_prompt:blurry, low quality
|
||||||
|
|
||||||
|
# Execute loaded workflow
|
||||||
|
ttr!workflow-gen
|
||||||
|
|
||||||
|
# Queue multiple workflow runs
|
||||||
|
ttr!workflow-gen queue:5
|
||||||
|
```
|
||||||
|
|
||||||
|
### Workflow Management
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Load a workflow from file
|
||||||
|
ttr!workflow-load path/to/workflow.json
|
||||||
|
|
||||||
|
# Or attach a JSON file to the message:
|
||||||
|
ttr!workflow-load
|
||||||
|
[Attach: my_workflow.json]
|
||||||
|
|
||||||
|
# View current workflow changes
|
||||||
|
ttr!get-current-workflow-changes type:all
|
||||||
|
|
||||||
|
# Set workflow parameters
|
||||||
|
ttr!set-current-workflow-changes type:prompt A new prompt
|
||||||
|
ttr!set-current-workflow-changes type:negative_prompt blurry
|
||||||
|
ttr!set-current-workflow-changes type:input_image input/image.png
|
||||||
|
```
|
||||||
|
|
||||||
|
### Image Upload
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Upload images to ComfyUI
|
||||||
|
ttr!upload
|
||||||
|
[Attach: image1.png, image2.png]
|
||||||
|
|
||||||
|
# Upload to specific folder
|
||||||
|
ttr!upload type:temp
|
||||||
|
[Attach: reference.png]
|
||||||
|
```
|
||||||
|
|
||||||
|
### History
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# View recent generations
|
||||||
|
ttr!history
|
||||||
|
|
||||||
|
# Retrieve images from a past generation
|
||||||
|
ttr!get-history <prompt_id>
|
||||||
|
ttr!get-history 1 # By index
|
||||||
|
```
|
||||||
|
|
||||||
|
### Command Aliases
|
||||||
|
|
||||||
|
Many commands have shorter aliases:
|
||||||
|
|
||||||
|
- `ttr!generate` → `ttr!gen`
|
||||||
|
- `ttr!workflow-gen` → `ttr!wfg`
|
||||||
|
- `ttr!workflow-load` → `ttr!wfl`
|
||||||
|
- `ttr!get-history` → `ttr!gh`
|
||||||
|
- `ttr!get-current-workflow-changes` → `ttr!gcwc`
|
||||||
|
- `ttr!set-current-workflow-changes` → `ttr!scwc`
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
The bot is organized into focused, maintainable modules:
|
||||||
|
|
||||||
|
```
|
||||||
|
the-third-rev/
|
||||||
|
├── config.py # Configuration and constants
|
||||||
|
├── job_queue.py # Job queue system
|
||||||
|
├── workflow_manager.py # Workflow manipulation
|
||||||
|
├── workflow_state.py # Runtime state management
|
||||||
|
├── discord_utils.py # Discord utilities
|
||||||
|
├── bot.py # Main entry point (~150 lines)
|
||||||
|
├── comfy_client.py # ComfyUI API client (~650 lines)
|
||||||
|
└── commands/ # Command handlers
|
||||||
|
├── generation.py # Image/video generation
|
||||||
|
├── workflow.py # Workflow management
|
||||||
|
├── upload.py # File uploads
|
||||||
|
├── history.py # History retrieval
|
||||||
|
└── workflow_changes.py # State management
|
||||||
|
```
|
||||||
|
|
||||||
|
### Key Design Principles
|
||||||
|
|
||||||
|
- **Dependency Injection** - Dependencies passed via constructor
|
||||||
|
- **Single Responsibility** - Each module has one clear purpose
|
||||||
|
- **Configuration Centralization** - All config in `config.py`
|
||||||
|
- **Command Separation** - Commands grouped by functionality
|
||||||
|
- **Type Safety** - Modern Python type hints throughout
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
### Adding a New Command
|
||||||
|
|
||||||
|
See `QUICK_START.md` for quick examples or `DEVELOPMENT.md` for comprehensive guide.
|
||||||
|
|
||||||
|
Basic example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# commands/your_module.py
|
||||||
|
|
||||||
|
def setup_your_commands(bot, config):
|
||||||
|
@bot.command(name="hello")
|
||||||
|
async def hello(ctx):
|
||||||
|
await ctx.reply("Hello!", mention_author=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
Register in `commands/__init__.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .your_module import setup_your_commands
|
||||||
|
|
||||||
|
def register_all_commands(bot, config):
|
||||||
|
# ... existing ...
|
||||||
|
setup_your_commands(bot, config)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
|
||||||
|
- **README.md** (this file) - Project overview and setup
|
||||||
|
- **QUICK_START.md** - Quick reference for common tasks
|
||||||
|
- **DEVELOPMENT.md** - Comprehensive development guide
|
||||||
|
- **CLAUDE.md** - Architecture documentation for Claude Code
|
||||||
|
|
||||||
|
## Workflow System
|
||||||
|
|
||||||
|
The bot supports two generation modes:
|
||||||
|
|
||||||
|
### 1. Prompt Mode (Simple)
|
||||||
|
|
||||||
|
Uses a workflow template with a KSampler node:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ttr!generate prompt:a cat negative_prompt:blurry
|
||||||
|
```
|
||||||
|
|
||||||
|
The bot automatically finds and replaces:
|
||||||
|
- Positive prompt in CLIPTextEncode node (title: "Positive Prompt")
|
||||||
|
- Negative prompt in CLIPTextEncode node (title: "Negative Prompt")
|
||||||
|
- Seed values (randomized each run)
|
||||||
|
|
||||||
|
### 2. Workflow Mode (Advanced)
|
||||||
|
|
||||||
|
Execute full workflow with runtime modifications:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Set workflow parameters
|
||||||
|
ttr!set-current-workflow-changes type:prompt A beautiful landscape
|
||||||
|
ttr!set-current-workflow-changes type:input_image input/reference.png
|
||||||
|
|
||||||
|
# Execute workflow
|
||||||
|
ttr!workflow-gen
|
||||||
|
```
|
||||||
|
|
||||||
|
The bot:
|
||||||
|
1. Loads the workflow template
|
||||||
|
2. Applies runtime changes from WorkflowStateManager
|
||||||
|
3. Randomizes seeds
|
||||||
|
4. Executes on ComfyUI server
|
||||||
|
5. Returns images/videos
|
||||||
|
|
||||||
|
### Node Naming Conventions
|
||||||
|
|
||||||
|
For workflows to work with dynamic updates, nodes must follow naming conventions:
|
||||||
|
|
||||||
|
- **Positive Prompt**: CLIPTextEncode node with title containing "Positive Prompt"
|
||||||
|
- **Negative Prompt**: CLIPTextEncode node with title containing "Negative Prompt"
|
||||||
|
- **Input Image**: LoadImage node (any title)
|
||||||
|
- **Seeds**: Any node with `inputs.seed` or `inputs.noise_seed`
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Bot won't start
|
||||||
|
|
||||||
|
**Issue**: `AttributeError: module 'queue' has no attribute 'SimpleQueue'`
|
||||||
|
|
||||||
|
**Solution**: This was fixed by renaming `queue.py` to `job_queue.py`. Make sure you're using the latest version.
|
||||||
|
|
||||||
|
### ComfyUI connection issues
|
||||||
|
|
||||||
|
**Issue**: `ComfyUI client is not configured`
|
||||||
|
|
||||||
|
**Solution**:
|
||||||
|
1. Check `.env` file has `DISCORD_BOT_TOKEN` and `COMFY_SERVER`
|
||||||
|
2. Verify ComfyUI server is running
|
||||||
|
3. Test connection: `curl http://localhost:8188`
|
||||||
|
|
||||||
|
### Commands not responding
|
||||||
|
|
||||||
|
**Issue**: Bot online but commands don't work
|
||||||
|
|
||||||
|
**Solution**:
|
||||||
|
1. Check bot has Message Content Intent enabled in Discord Developer Portal
|
||||||
|
2. Verify bot has permissions in Discord server
|
||||||
|
3. Check console logs for errors
|
||||||
|
|
||||||
|
### Video files not found
|
||||||
|
|
||||||
|
**Issue**: `Failed to read video file`
|
||||||
|
|
||||||
|
**Solution**:
|
||||||
|
1. Set `COMFY_OUTPUT_PATH` in `.env` to your ComfyUI output directory
|
||||||
|
2. Check path uses correct format for your OS
|
||||||
|
|
||||||
|
## Advanced Usage
|
||||||
|
|
||||||
|
### Batch Generation
|
||||||
|
|
||||||
|
Queue multiple workflow runs:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ttr!workflow-gen queue:10
|
||||||
|
```
|
||||||
|
|
||||||
|
Each run uses randomized seeds for variation.
|
||||||
|
|
||||||
|
### Custom Workflows
|
||||||
|
|
||||||
|
1. Design workflow in ComfyUI
|
||||||
|
2. Export as API format (Save → API Format)
|
||||||
|
3. Load in bot:
|
||||||
|
```bash
|
||||||
|
ttr!workflow-load path/to/workflow.json
|
||||||
|
```
|
||||||
|
4. Modify at runtime:
|
||||||
|
```bash
|
||||||
|
ttr!set-current-workflow-changes type:prompt My prompt
|
||||||
|
ttr!workflow-gen
|
||||||
|
```
|
||||||
|
|
||||||
|
### State Persistence
|
||||||
|
|
||||||
|
Workflow changes are automatically saved to `current-workflow-changes.json` and persist across bot restarts.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
We welcome contributions! Please:
|
||||||
|
|
||||||
|
1. Read `DEVELOPMENT.md` for coding guidelines
|
||||||
|
2. Follow existing code style and patterns
|
||||||
|
3. Test your changes thoroughly
|
||||||
|
4. Update documentation as needed
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
[Your License Here]
|
||||||
|
|
||||||
|
## Support
|
||||||
|
|
||||||
|
For issues or questions:
|
||||||
|
- Check the troubleshooting section above
|
||||||
|
- Review `DEVELOPMENT.md` for implementation details
|
||||||
|
- Check ComfyUI documentation for workflow issues
|
||||||
|
- Open an issue on GitHub
|
||||||
|
|
||||||
|
## Credits
|
||||||
|
|
||||||
|
Built with:
|
||||||
|
- [discord.py](https://github.com/Rapptz/discord.py) - Discord API wrapper
|
||||||
|
- [ComfyUI](https://github.com/comfyanonymous/ComfyUI) - Stable Diffusion GUI
|
||||||
|
- [aiohttp](https://github.com/aio-libs/aiohttp) - Async HTTP client
|
||||||
|
- [websockets](https://github.com/python-websockets/websockets) - WebSocket implementation
|
||||||
156
backfill_image_data.py
Normal file
156
backfill_image_data.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""
|
||||||
|
backfill_image_data.py
|
||||||
|
======================
|
||||||
|
|
||||||
|
One-shot script to download image bytes from Discord and store them in
|
||||||
|
input_images.db for rows that currently have image_data = NULL.
|
||||||
|
|
||||||
|
These rows were created before the BLOB-storage migration, so their bytes
|
||||||
|
were never persisted. The script re-fetches each bot-reply message from
|
||||||
|
Discord and writes the raw attachment bytes back into the DB.
|
||||||
|
|
||||||
|
Rows with bot_reply_id = 0 (web uploads that pre-date the migration) have
|
||||||
|
no Discord source and are skipped — re-upload them via the web UI to
|
||||||
|
backfill.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
python backfill_image_data.py
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
DISCORD_BOT_TOKEN in .env (same token the bot uses)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
import discord
|
||||||
|
|
||||||
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from config import BotConfig
|
||||||
|
from input_image_db import DB_PATH
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_null_rows() -> list[dict]:
|
||||||
|
"""Return all rows that are missing image_data."""
|
||||||
|
conn = sqlite3.connect(str(DB_PATH))
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT id, bot_reply_id, channel_id, filename"
|
||||||
|
" FROM input_images WHERE image_data IS NULL"
|
||||||
|
).fetchall()
|
||||||
|
conn.close()
|
||||||
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def _save_image_data(row_id: int, data: bytes) -> None:
|
||||||
|
conn = sqlite3.connect(str(DB_PATH))
|
||||||
|
conn.execute("UPDATE input_images SET image_data = ? WHERE id = ?", (data, row_id))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def _backfill(client: discord.Client) -> None:
|
||||||
|
rows = _load_null_rows()
|
||||||
|
|
||||||
|
discord_rows = [r for r in rows if r["bot_reply_id"] != 0]
|
||||||
|
web_rows = [r for r in rows if r["bot_reply_id"] == 0]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Rows missing image_data: %d total (%d from Discord, %d web-uploads skipped)",
|
||||||
|
len(rows), len(discord_rows), len(web_rows),
|
||||||
|
)
|
||||||
|
|
||||||
|
if web_rows:
|
||||||
|
logger.info(
|
||||||
|
"Skipped row IDs (no Discord source — re-upload via web UI): %s",
|
||||||
|
[r["id"] for r in web_rows],
|
||||||
|
)
|
||||||
|
|
||||||
|
if not discord_rows:
|
||||||
|
logger.info("Nothing to fetch. Exiting.")
|
||||||
|
return
|
||||||
|
|
||||||
|
ok = 0
|
||||||
|
failed = 0
|
||||||
|
|
||||||
|
for row in discord_rows:
|
||||||
|
row_id = row["id"]
|
||||||
|
ch_id = row["channel_id"]
|
||||||
|
msg_id = row["bot_reply_id"]
|
||||||
|
filename = row["filename"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
channel = client.get_channel(ch_id) or await client.fetch_channel(ch_id)
|
||||||
|
message = await channel.fetch_message(msg_id)
|
||||||
|
|
||||||
|
attachment = next(
|
||||||
|
(a for a in message.attachments if a.filename == filename), None
|
||||||
|
)
|
||||||
|
if attachment is None:
|
||||||
|
logger.warning(
|
||||||
|
"Row %d: attachment '%s' not found on message %d — skipping",
|
||||||
|
row_id, filename, msg_id,
|
||||||
|
)
|
||||||
|
failed += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
data = await attachment.read()
|
||||||
|
_save_image_data(row_id, data)
|
||||||
|
logger.info("Row %d: saved '%s' (%d bytes)", row_id, filename, len(data))
|
||||||
|
ok += 1
|
||||||
|
|
||||||
|
except discord.NotFound:
|
||||||
|
logger.warning("Row %d: message %d not found (deleted?) — skipping", row_id, msg_id)
|
||||||
|
failed += 1
|
||||||
|
except discord.Forbidden:
|
||||||
|
logger.warning("Row %d: no access to channel %d — skipping", row_id, ch_id)
|
||||||
|
failed += 1
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Row %d: unexpected error — %s", row_id, exc)
|
||||||
|
failed += 1
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Done. %d saved, %d failed/skipped, %d web-upload rows not touched.",
|
||||||
|
ok, failed, len(web_rows),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _main(token: str) -> None:
|
||||||
|
intents = discord.Intents.none() # no gateway events needed beyond connect
|
||||||
|
client = discord.Client(intents=intents)
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_ready():
|
||||||
|
logger.info("Logged in as %s", client.user)
|
||||||
|
try:
|
||||||
|
await _backfill(client)
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
await client.start(token)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
try:
|
||||||
|
config = BotConfig.from_env()
|
||||||
|
except RuntimeError as exc:
|
||||||
|
logger.error("Config error: %s", exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
asyncio.run(_main(config.discord_bot_token))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
218
bot.py
Normal file
218
bot.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
"""
|
||||||
|
bot.py
|
||||||
|
======
|
||||||
|
|
||||||
|
Discord bot entry point. In WEB_ENABLED mode, also starts a FastAPI/Uvicorn
|
||||||
|
web server in the same asyncio event loop via asyncio.gather.
|
||||||
|
|
||||||
|
Jobs are submitted directly to ComfyUI — no internal SerialJobQueue.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from comfy_client import ComfyClient
|
||||||
|
from config import BotConfig, COMMAND_PREFIX
|
||||||
|
import generation_db
|
||||||
|
from input_image_db import init_db, get_all_images
|
||||||
|
from status_monitor import StatusMonitor
|
||||||
|
from workflow_manager import WorkflowManager
|
||||||
|
from workflow_state import WorkflowStateManager
|
||||||
|
from commands import register_all_commands, CustomHelpCommand
|
||||||
|
from commands.input_images import PersistentSetInputView
|
||||||
|
from commands.server import autostart_comfy
|
||||||
|
|
||||||
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_PROJECT_ROOT = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bot setup
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_prefix(bot, message):
|
||||||
|
"""Dynamic command prefix getter."""
|
||||||
|
msg = message.content.lower()
|
||||||
|
if msg.startswith(COMMAND_PREFIX):
|
||||||
|
return COMMAND_PREFIX
|
||||||
|
return COMMAND_PREFIX
|
||||||
|
|
||||||
|
|
||||||
|
intents = discord.Intents.default()
|
||||||
|
intents.message_content = True
|
||||||
|
intents.guilds = True
|
||||||
|
|
||||||
|
bot = commands.Bot(
|
||||||
|
command_prefix=get_prefix,
|
||||||
|
intents=intents,
|
||||||
|
help_command=CustomHelpCommand(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Event handlers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@bot.event
|
||||||
|
async def on_ready() -> None:
|
||||||
|
logger.info("Logged in as %s (ID: %s)", bot.user, bot.user.id)
|
||||||
|
if not hasattr(bot, "start_time"):
|
||||||
|
bot.start_time = datetime.now(timezone.utc)
|
||||||
|
cfg = getattr(bot, "config", None)
|
||||||
|
if cfg:
|
||||||
|
for row in get_all_images():
|
||||||
|
view = PersistentSetInputView(bot, cfg, row["id"])
|
||||||
|
bot.add_view(view, message_id=row["bot_reply_id"])
|
||||||
|
asyncio.create_task(autostart_comfy(cfg))
|
||||||
|
|
||||||
|
if not hasattr(bot, "status_monitor"):
|
||||||
|
log_ch = getattr(getattr(bot, "config", None), "log_channel_id", None)
|
||||||
|
if log_ch:
|
||||||
|
bot.status_monitor = StatusMonitor(bot, log_ch)
|
||||||
|
if hasattr(bot, "status_monitor"):
|
||||||
|
await bot.status_monitor.start()
|
||||||
|
|
||||||
|
|
||||||
|
@bot.event
|
||||||
|
async def on_disconnect() -> None:
|
||||||
|
logger.info("Discord connection closed")
|
||||||
|
|
||||||
|
|
||||||
|
@bot.event
|
||||||
|
async def on_resumed() -> None:
|
||||||
|
logger.info("Discord session resumed")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Startup helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def create_comfy(config: BotConfig) -> ComfyClient:
|
||||||
|
state_manager = WorkflowStateManager(state_file="current-workflow-changes.json")
|
||||||
|
workflow_manager = WorkflowManager()
|
||||||
|
return ComfyClient(
|
||||||
|
server_address=config.comfy_server,
|
||||||
|
workflow_manager=workflow_manager,
|
||||||
|
state_manager=state_manager,
|
||||||
|
logger=logger,
|
||||||
|
history_limit=config.comfy_history_limit,
|
||||||
|
output_path=config.comfy_output_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _try_autoload_last_workflow(client: ComfyClient) -> None:
|
||||||
|
"""Re-load the last used workflow from the workflows/ folder on startup."""
|
||||||
|
last_wf = client.state_manager.get_last_workflow_file()
|
||||||
|
if not last_wf:
|
||||||
|
return
|
||||||
|
wf_path = _PROJECT_ROOT / "workflows" / last_wf
|
||||||
|
if not wf_path.exists():
|
||||||
|
logger.warning("Last workflow file not found: %s", wf_path)
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
with open(wf_path, "r", encoding="utf-8") as f:
|
||||||
|
workflow = json.load(f)
|
||||||
|
# Restore template without clearing overrides on restart
|
||||||
|
client.workflow_manager.set_workflow_template(workflow)
|
||||||
|
logger.info("Auto-loaded last workflow: %s", last_wf)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to auto-load workflow %s: %s", last_wf, exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main entry point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def main() -> None:
|
||||||
|
try:
|
||||||
|
config: BotConfig = BotConfig.from_env()
|
||||||
|
logger.info("Configuration loaded")
|
||||||
|
except RuntimeError as exc:
|
||||||
|
logger.error("Configuration error: %s", exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
bot.comfy = await create_comfy(config)
|
||||||
|
bot.config = config
|
||||||
|
|
||||||
|
# Auto-load last workflow (restores template without clearing overrides)
|
||||||
|
_try_autoload_last_workflow(bot.comfy)
|
||||||
|
|
||||||
|
# Fallback: WORKFLOW_FILE env var
|
||||||
|
if not bot.comfy.get_workflow_template() and config.workflow_file:
|
||||||
|
try:
|
||||||
|
bot.comfy.load_workflow_from_file(config.workflow_file)
|
||||||
|
logger.info("Loaded workflow from %s", config.workflow_file)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to load workflow %s: %s", config.workflow_file, exc)
|
||||||
|
|
||||||
|
from user_state_registry import UserStateRegistry
|
||||||
|
bot.user_registry = UserStateRegistry(
|
||||||
|
settings_dir=_PROJECT_ROOT / "user_settings",
|
||||||
|
default_workflow=bot.comfy.get_workflow_template(),
|
||||||
|
)
|
||||||
|
|
||||||
|
init_db()
|
||||||
|
generation_db.init_db(_PROJECT_ROOT / "generation_history.db")
|
||||||
|
register_all_commands(bot, config)
|
||||||
|
logger.info("All commands registered")
|
||||||
|
|
||||||
|
coroutines = [bot.start(config.discord_bot_token)]
|
||||||
|
|
||||||
|
if config.web_enabled:
|
||||||
|
try:
|
||||||
|
import uvicorn
|
||||||
|
from web.deps import set_bot
|
||||||
|
from web.app import create_app
|
||||||
|
|
||||||
|
set_bot(bot)
|
||||||
|
fastapi_app = create_app()
|
||||||
|
|
||||||
|
uvi_config = uvicorn.Config(
|
||||||
|
fastapi_app,
|
||||||
|
host=config.web_host,
|
||||||
|
port=config.web_port,
|
||||||
|
log_level="info",
|
||||||
|
loop="none", # use existing event loop
|
||||||
|
)
|
||||||
|
uvi_server = uvicorn.Server(uvi_config)
|
||||||
|
coroutines.append(uvi_server.serve())
|
||||||
|
logger.info(
|
||||||
|
"Web UI enabled at http://%s:%d", config.web_host, config.web_port
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"uvicorn or fastapi not installed — web UI disabled. "
|
||||||
|
"pip install fastapi uvicorn[standard]"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*coroutines)
|
||||||
|
finally:
|
||||||
|
if hasattr(bot, "status_monitor"):
|
||||||
|
await bot.status_monitor.stop()
|
||||||
|
if hasattr(bot, "comfy"):
|
||||||
|
await bot.comfy.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Received interrupt, shutting down…")
|
||||||
604
comfy_client.py
Normal file
604
comfy_client.py
Normal file
@@ -0,0 +1,604 @@
|
|||||||
|
"""
|
||||||
|
comfy_client.py
|
||||||
|
================
|
||||||
|
|
||||||
|
Asynchronous client for the ComfyUI API.
|
||||||
|
|
||||||
|
Wraps ComfyUI's REST and WebSocket endpoints. Workflow template injection
|
||||||
|
is now handled by :class:`~workflow_inspector.WorkflowInspector`, so this
|
||||||
|
class only needs to:
|
||||||
|
|
||||||
|
1. Accept a workflow template (delegated to WorkflowManager).
|
||||||
|
2. Accept runtime overrides (delegated to WorkflowStateManager).
|
||||||
|
3. Build the final workflow via inspector.inject_overrides().
|
||||||
|
4. Queue it to ComfyUI, wait for completion via WebSocket, fetch outputs.
|
||||||
|
|
||||||
|
A ``{prompt_id: callback}`` map is maintained for future WebSocket
|
||||||
|
broadcasting (web UI phase). Discord commands still use the synchronous
|
||||||
|
await-and-return model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from workflow_inspector import WorkflowInspector
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyClient:
|
||||||
|
"""
|
||||||
|
Asynchronous ComfyUI client.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
server_address : str
|
||||||
|
``hostname:port`` of the ComfyUI server.
|
||||||
|
workflow_manager : WorkflowManager
|
||||||
|
Template storage (injected).
|
||||||
|
state_manager : WorkflowStateManager
|
||||||
|
Runtime overrides (injected).
|
||||||
|
logger : Optional[logging.Logger]
|
||||||
|
Logger for debug/info messages.
|
||||||
|
history_limit : int
|
||||||
|
Max recent generations to keep in the in-memory deque.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_address: str,
|
||||||
|
workflow_manager,
|
||||||
|
state_manager,
|
||||||
|
logger: Optional[logging.Logger] = None,
|
||||||
|
*,
|
||||||
|
history_limit: int = 10,
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
self.server_address = server_address.strip().rstrip("/")
|
||||||
|
self.client_id = str(uuid.uuid4())
|
||||||
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
|
|
||||||
|
self.protocol = "http"
|
||||||
|
self.ws_protocol = "ws"
|
||||||
|
|
||||||
|
self.workflow_manager = workflow_manager
|
||||||
|
self.state_manager = state_manager
|
||||||
|
self.inspector = WorkflowInspector()
|
||||||
|
self.output_path = output_path
|
||||||
|
|
||||||
|
# prompt_id → asyncio.Future for web-UI broadcast (Phase 4)
|
||||||
|
self._pending_callbacks: Dict[str, Callable] = {}
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
self._history = deque(maxlen=history_limit)
|
||||||
|
|
||||||
|
self.last_prompt_id: Optional[str] = None
|
||||||
|
self.last_seed: Optional[int] = None
|
||||||
|
self.total_generated: int = 0
|
||||||
|
|
||||||
|
self.logger = logger if logger else logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Session
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self) -> aiohttp.ClientSession:
|
||||||
|
"""Lazily create and return an aiohttp session."""
|
||||||
|
if self._session is None or self._session.closed:
|
||||||
|
self._session = aiohttp.ClientSession()
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Low-level REST helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _queue_prompt(
|
||||||
|
self,
|
||||||
|
prompt: dict[str, Any],
|
||||||
|
prompt_id: str,
|
||||||
|
ws_client_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Submit a workflow to the ComfyUI queue."""
|
||||||
|
payload = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"client_id": ws_client_id if ws_client_id is not None else self.client_id,
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
}
|
||||||
|
url = f"{self.protocol}://{self.server_address}/prompt"
|
||||||
|
async with self.session.post(url, json=payload,
|
||||||
|
headers={"Content-Type": "application/json"}) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.json()
|
||||||
|
|
||||||
|
async def _wait_for_execution(
|
||||||
|
self,
|
||||||
|
prompt_id: str,
|
||||||
|
on_progress: Optional[Callable[[str, str], None]] = None,
|
||||||
|
ws_client_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Wait for a queued prompt to finish executing via WebSocket.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
prompt_id : str
|
||||||
|
The prompt to wait for.
|
||||||
|
on_progress : Optional[Callable[[str, str], None]]
|
||||||
|
Called with ``(node_id, prompt_id)`` for each ``node_executing``
|
||||||
|
event. Pass ``None`` for Discord commands (no web broadcast
|
||||||
|
needed).
|
||||||
|
"""
|
||||||
|
client_id = ws_client_id if ws_client_id is not None else self.client_id
|
||||||
|
ws_url = (
|
||||||
|
f"{self.ws_protocol}://{self.server_address}/ws"
|
||||||
|
f"?clientId={client_id}"
|
||||||
|
)
|
||||||
|
async with websockets.connect(ws_url) as ws:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
out = await ws.recv()
|
||||||
|
if not isinstance(out, str):
|
||||||
|
continue
|
||||||
|
message = json.loads(out)
|
||||||
|
mtype = message.get("type")
|
||||||
|
|
||||||
|
if mtype == "executing":
|
||||||
|
data = message["data"]
|
||||||
|
node = data.get("node")
|
||||||
|
if node:
|
||||||
|
self.logger.debug("Executing node: %s", node)
|
||||||
|
if on_progress and data.get("prompt_id") == prompt_id:
|
||||||
|
try:
|
||||||
|
on_progress(node, prompt_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if data["node"] is None and data.get("prompt_id") == prompt_id:
|
||||||
|
self.logger.info("Execution complete for prompt %s", prompt_id)
|
||||||
|
break
|
||||||
|
|
||||||
|
elif mtype == "execution_success":
|
||||||
|
if message.get("data", {}).get("prompt_id") == prompt_id:
|
||||||
|
self.logger.info("execution_success for prompt %s", prompt_id)
|
||||||
|
break
|
||||||
|
|
||||||
|
elif mtype == "execution_error":
|
||||||
|
if message.get("data", {}).get("prompt_id") == prompt_id:
|
||||||
|
error = message.get("data", {}).get("exception_message", "unknown error")
|
||||||
|
raise RuntimeError(f"ComfyUI execution error: {error}")
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
self.logger.error("Error during execution wait: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _get_history(self, prompt_id: str) -> dict[str, Any]:
|
||||||
|
"""Retrieve execution history for a given prompt id."""
|
||||||
|
url = f"{self.protocol}://{self.server_address}/history/{prompt_id}"
|
||||||
|
async with self.session.get(url) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.json()
|
||||||
|
|
||||||
|
async def _download_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
|
||||||
|
"""Download an image from ComfyUI and return raw bytes."""
|
||||||
|
url = f"{self.protocol}://{self.server_address}/view"
|
||||||
|
params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||||
|
async with self.session.get(url, params=params) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.read()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Core generation pipeline
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _general_generate(
|
||||||
|
self,
|
||||||
|
workflow: dict[str, Any],
|
||||||
|
prompt_id: str,
|
||||||
|
on_progress: Optional[Callable[[str, str], None]] = None,
|
||||||
|
) -> tuple[List[bytes], List[dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Queue a workflow, wait for it to execute, then collect outputs.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[List[bytes], List[dict]]
|
||||||
|
``(images, videos)`` — images as raw bytes, videos as info dicts.
|
||||||
|
"""
|
||||||
|
ws_client_id = str(uuid.uuid4())
|
||||||
|
await self._queue_prompt(workflow, prompt_id, ws_client_id)
|
||||||
|
try:
|
||||||
|
await self._wait_for_execution(prompt_id, on_progress=on_progress, ws_client_id=ws_client_id)
|
||||||
|
except Exception:
|
||||||
|
self.logger.error("Execution failed for prompt %s", prompt_id)
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
history = await self._get_history(prompt_id)
|
||||||
|
if not history:
|
||||||
|
self.logger.warning("No history for prompt %s", prompt_id)
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
images: List[bytes] = []
|
||||||
|
videos: List[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for node_output in history.get(prompt_id, {}).get("outputs", {}).values():
|
||||||
|
for image_info in node_output.get("images", []):
|
||||||
|
name = image_info["filename"]
|
||||||
|
if name.rsplit(".", 1)[-1].lower() in {"mp4", "webm", "avi"}:
|
||||||
|
videos.append({
|
||||||
|
"video_name": name,
|
||||||
|
"video_subfolder": image_info.get("subfolder", ""),
|
||||||
|
"video_type": image_info.get("type", "output"),
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
data = await self._download_image(
|
||||||
|
name, image_info["subfolder"], image_info["type"]
|
||||||
|
)
|
||||||
|
images.append(data)
|
||||||
|
|
||||||
|
return images, videos
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# DB persistence helper
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _record_to_db(
|
||||||
|
self,
|
||||||
|
prompt_id: str,
|
||||||
|
source: str,
|
||||||
|
user_label: Optional[str],
|
||||||
|
overrides: Dict[str, Any],
|
||||||
|
seed: Optional[int],
|
||||||
|
images: List[bytes],
|
||||||
|
videos: List[Dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
"""Persist generation metadata and file blobs to SQLite. Never raises."""
|
||||||
|
try:
|
||||||
|
import generation_db
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
gen_id = generation_db.record_generation(
|
||||||
|
prompt_id, source, user_label, overrides, seed
|
||||||
|
)
|
||||||
|
for i, img_data in enumerate(images):
|
||||||
|
generation_db.record_file(gen_id, f"image_{i:04d}.png", img_data)
|
||||||
|
if videos and self.output_path:
|
||||||
|
for vid in videos:
|
||||||
|
vname = vid.get("video_name", "")
|
||||||
|
vsub = vid.get("video_subfolder", "")
|
||||||
|
vpath = (
|
||||||
|
_Path(self.output_path) / vsub / vname
|
||||||
|
if vsub
|
||||||
|
else _Path(self.output_path) / vname
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
generation_db.record_file(gen_id, vname, vpath.read_bytes())
|
||||||
|
except OSError as exc:
|
||||||
|
self.logger.warning(
|
||||||
|
"Could not read video for DB storage: %s: %s", vpath, exc
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
self.logger.warning("Failed to record generation to DB: %s", exc)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public generation API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def generate_image(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: Optional[str] = None,
|
||||||
|
on_progress: Optional[Callable[[str, str], None]] = None,
|
||||||
|
*,
|
||||||
|
source: str = "discord",
|
||||||
|
user_label: Optional[str] = None,
|
||||||
|
) -> tuple[List[bytes], str]:
|
||||||
|
"""
|
||||||
|
Generate images using the current workflow template with a text prompt.
|
||||||
|
|
||||||
|
Injects *prompt* (and optionally *negative_prompt*) via the inspector,
|
||||||
|
plus any currently pinned seed from the state manager. All other
|
||||||
|
overrides in the state manager are **not** applied here — use
|
||||||
|
:meth:`generate_image_with_workflow` for the full override set.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
prompt : str
|
||||||
|
Positive prompt text.
|
||||||
|
negative_prompt : Optional[str]
|
||||||
|
Negative prompt text (optional).
|
||||||
|
on_progress : Optional[Callable]
|
||||||
|
Called with ``(node_id, prompt_id)`` for each executing node.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[List[bytes], str]
|
||||||
|
``(images, prompt_id)``
|
||||||
|
"""
|
||||||
|
template = self.workflow_manager.get_workflow_template()
|
||||||
|
if not template:
|
||||||
|
self.logger.warning("No workflow template set; cannot generate.")
|
||||||
|
return [], ""
|
||||||
|
|
||||||
|
overrides: Dict[str, Any] = {"prompt": prompt}
|
||||||
|
if negative_prompt is not None:
|
||||||
|
overrides["negative_prompt"] = negative_prompt
|
||||||
|
# Respect pinned seed from state manager
|
||||||
|
seed_pin = self.state_manager.get_seed()
|
||||||
|
if seed_pin is not None:
|
||||||
|
overrides["seed"] = seed_pin
|
||||||
|
|
||||||
|
workflow, applied = self.inspector.inject_overrides(template, overrides)
|
||||||
|
seed_used = applied.get("seed")
|
||||||
|
self.last_seed = seed_used
|
||||||
|
|
||||||
|
prompt_id = str(uuid.uuid4())
|
||||||
|
images, _videos = await self._general_generate(workflow, prompt_id, on_progress)
|
||||||
|
|
||||||
|
self.last_prompt_id = prompt_id
|
||||||
|
self.total_generated += 1
|
||||||
|
self._history.append({
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"seed": seed_used,
|
||||||
|
})
|
||||||
|
self._record_to_db(
|
||||||
|
prompt_id, source, user_label,
|
||||||
|
{"prompt": prompt, "negative_prompt": negative_prompt},
|
||||||
|
seed_used, images, [],
|
||||||
|
)
|
||||||
|
return images, prompt_id
|
||||||
|
|
||||||
|
async def generate_image_with_workflow(
|
||||||
|
self,
|
||||||
|
on_progress: Optional[Callable[[str, str], None]] = None,
|
||||||
|
*,
|
||||||
|
source: str = "discord",
|
||||||
|
user_label: Optional[str] = None,
|
||||||
|
) -> tuple[List[bytes], List[dict[str, Any]], str]:
|
||||||
|
"""
|
||||||
|
Generate images/videos from the current workflow applying ALL
|
||||||
|
overrides stored in the state manager.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[List[bytes], List[dict], str]
|
||||||
|
``(images, videos, prompt_id)``
|
||||||
|
"""
|
||||||
|
template = self.workflow_manager.get_workflow_template()
|
||||||
|
prompt_id = str(uuid.uuid4())
|
||||||
|
if not template:
|
||||||
|
self.logger.error("No workflow template set")
|
||||||
|
return [], [], prompt_id
|
||||||
|
|
||||||
|
overrides = self.state_manager.get_overrides()
|
||||||
|
workflow, applied = self.inspector.inject_overrides(template, overrides)
|
||||||
|
seed_used = applied.get("seed")
|
||||||
|
self.last_seed = seed_used
|
||||||
|
|
||||||
|
images, videos = await self._general_generate(workflow, prompt_id, on_progress)
|
||||||
|
|
||||||
|
self.last_prompt_id = prompt_id
|
||||||
|
self.total_generated += 1
|
||||||
|
prompt_str = overrides.get("prompt") or ""
|
||||||
|
neg_str = overrides.get("negative_prompt") or ""
|
||||||
|
self._history.append({
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
"prompt": (prompt_str[:10] + "…") if len(prompt_str) > 10 else prompt_str or None,
|
||||||
|
"negative_prompt": (neg_str[:10] + "…") if len(neg_str) > 10 else neg_str or None,
|
||||||
|
"seed": seed_used,
|
||||||
|
})
|
||||||
|
self._record_to_db(prompt_id, source, user_label, overrides, seed_used, images, videos)
|
||||||
|
return images, videos, prompt_id
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Workflow template management
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def set_workflow(self, workflow: dict[str, Any]) -> None:
|
||||||
|
"""Set the workflow template and clear all state overrides."""
|
||||||
|
self.workflow_manager.set_workflow_template(workflow)
|
||||||
|
self.state_manager.clear_overrides()
|
||||||
|
|
||||||
|
def load_workflow_from_file(self, path: str) -> None:
|
||||||
|
"""
|
||||||
|
Load a workflow template from a JSON file.
|
||||||
|
|
||||||
|
Also clears state overrides and records the filename in the state
|
||||||
|
manager for auto-load on restart.
|
||||||
|
"""
|
||||||
|
import json as _json
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
workflow = _json.load(f)
|
||||||
|
self.workflow_manager.set_workflow_template(workflow)
|
||||||
|
self.state_manager.clear_overrides()
|
||||||
|
from pathlib import Path
|
||||||
|
self.state_manager.set_last_workflow_file(Path(path).name)
|
||||||
|
|
||||||
|
def get_workflow_template(self) -> Optional[dict[str, Any]]:
|
||||||
|
"""Return the current workflow template (or None)."""
|
||||||
|
return self.workflow_manager.get_workflow_template()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# State management convenience wrappers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_workflow_current_changes(self) -> dict[str, Any]:
|
||||||
|
"""Return all current overrides (backward-compat)."""
|
||||||
|
return self.state_manager.get_changes()
|
||||||
|
|
||||||
|
def set_workflow_current_changes(self, changes: dict[str, Any]) -> None:
|
||||||
|
"""Merge override changes (backward-compat)."""
|
||||||
|
self.state_manager.set_changes(changes, merge=True)
|
||||||
|
|
||||||
|
def set_workflow_current_prompt(self, prompt: str) -> None:
|
||||||
|
self.state_manager.set_prompt(prompt)
|
||||||
|
|
||||||
|
def set_workflow_current_negative_prompt(self, negative_prompt: str) -> None:
|
||||||
|
self.state_manager.set_negative_prompt(negative_prompt)
|
||||||
|
|
||||||
|
def set_workflow_current_input_image(self, input_image: str) -> None:
|
||||||
|
self.state_manager.set_input_image(input_image)
|
||||||
|
|
||||||
|
def get_current_workflow_prompt(self) -> Optional[str]:
|
||||||
|
return self.state_manager.get_prompt()
|
||||||
|
|
||||||
|
def get_current_workflow_negative_prompt(self) -> Optional[str]:
|
||||||
|
return self.state_manager.get_negative_prompt()
|
||||||
|
|
||||||
|
def get_current_workflow_input_image(self) -> Optional[str]:
|
||||||
|
return self.state_manager.get_input_image()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Image upload
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def upload_image(
|
||||||
|
self,
|
||||||
|
data: bytes,
|
||||||
|
filename: str,
|
||||||
|
*,
|
||||||
|
image_type: str = "input",
|
||||||
|
overwrite: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Upload an image to ComfyUI via the /upload/image endpoint."""
|
||||||
|
url = f"{self.protocol}://{self.server_address}/upload/image"
|
||||||
|
form = aiohttp.FormData()
|
||||||
|
form.add_field("image", data, filename=filename,
|
||||||
|
content_type="application/octet-stream")
|
||||||
|
form.add_field("type", image_type)
|
||||||
|
form.add_field("overwrite", str(overwrite).lower())
|
||||||
|
async with self.session.post(url, data=form) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
try:
|
||||||
|
return await resp.json()
|
||||||
|
except aiohttp.ContentTypeError:
|
||||||
|
return {"status": await resp.text()}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# History
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_history(self) -> List[dict]:
|
||||||
|
"""Return a list of recently generated prompt records (from DB)."""
|
||||||
|
try:
|
||||||
|
from generation_db import get_history as db_get_history
|
||||||
|
return db_get_history(limit=self._history.maxlen or 50)
|
||||||
|
except Exception:
|
||||||
|
return list(self._history)
|
||||||
|
|
||||||
|
async def fetch_history_images(self, prompt_id: str) -> List[bytes]:
|
||||||
|
"""Re-download images for a previously generated prompt."""
|
||||||
|
history = await self._get_history(prompt_id)
|
||||||
|
images: List[bytes] = []
|
||||||
|
for node_output in history.get(prompt_id, {}).get("outputs", {}).values():
|
||||||
|
for image_info in node_output.get("images", []):
|
||||||
|
data = await self._download_image(
|
||||||
|
image_info["filename"],
|
||||||
|
image_info["subfolder"],
|
||||||
|
image_info["type"],
|
||||||
|
)
|
||||||
|
images.append(data)
|
||||||
|
return images
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Server info / queue
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def get_system_stats(self) -> Optional[dict[str, Any]]:
|
||||||
|
"""Fetch ComfyUI system stats (/system_stats)."""
|
||||||
|
try:
|
||||||
|
url = f"{self.protocol}://{self.server_address}/system_stats"
|
||||||
|
async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.json()
|
||||||
|
except Exception as exc:
|
||||||
|
self.logger.warning("Failed to fetch system stats: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_comfy_queue(self) -> Optional[dict[str, Any]]:
|
||||||
|
"""Fetch the ComfyUI queue (/queue)."""
|
||||||
|
try:
|
||||||
|
url = f"{self.protocol}://{self.server_address}/queue"
|
||||||
|
async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.json()
|
||||||
|
except Exception as exc:
|
||||||
|
self.logger.warning("Failed to fetch comfy queue: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_queue_depth(self) -> int:
|
||||||
|
"""Return the total number of pending + running jobs in ComfyUI."""
|
||||||
|
q = await self.get_comfy_queue()
|
||||||
|
if q:
|
||||||
|
return len(q.get("queue_running", [])) + len(q.get("queue_pending", []))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def clear_queue(self) -> bool:
|
||||||
|
"""Clear all pending jobs from the ComfyUI queue."""
|
||||||
|
try:
|
||||||
|
url = f"{self.protocol}://{self.server_address}/queue"
|
||||||
|
async with self.session.post(
|
||||||
|
url,
|
||||||
|
json={"clear": True},
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
timeout=aiohttp.ClientTimeout(total=5),
|
||||||
|
) as resp:
|
||||||
|
return resp.status in (200, 204)
|
||||||
|
except Exception as exc:
|
||||||
|
self.logger.warning("Failed to clear comfy queue: %s", exc)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def check_connection(self) -> bool:
|
||||||
|
"""Return True if the ComfyUI server is reachable."""
|
||||||
|
try:
|
||||||
|
url = f"{self.protocol}://{self.server_address}/system_stats"
|
||||||
|
async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
|
||||||
|
return resp.status == 200
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_models(self, model_type: str = "checkpoints") -> List[str]:
|
||||||
|
"""
|
||||||
|
Fetch available model names from ComfyUI.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_type : str
|
||||||
|
One of ``"checkpoints"``, ``"loras"``, etc.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
url = f"{self.protocol}://{self.server_address}/object_info"
|
||||||
|
async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
info = await resp.json()
|
||||||
|
if model_type == "checkpoints":
|
||||||
|
node = info.get("CheckpointLoaderSimple", {})
|
||||||
|
return node.get("input", {}).get("required", {}).get("ckpt_name", [None])[0] or []
|
||||||
|
elif model_type == "loras":
|
||||||
|
node = info.get("LoraLoader", {})
|
||||||
|
return node.get("input", {}).get("required", {}).get("lora_name", [None])[0] or []
|
||||||
|
return []
|
||||||
|
except Exception as exc:
|
||||||
|
self.logger.warning("Failed to fetch models (%s): %s", model_type, exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Lifecycle
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the underlying aiohttp session."""
|
||||||
|
if self._session and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "ComfyClient":
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||||
|
await self.close()
|
||||||
64
commands/__init__.py
Normal file
64
commands/__init__.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""
|
||||||
|
commands package
|
||||||
|
================
|
||||||
|
|
||||||
|
Discord bot commands for the ComfyUI bot.
|
||||||
|
|
||||||
|
This package contains all command handlers organized by functionality:
|
||||||
|
- generation: Image/video generation commands (generate, workflow-gen, rerun, cancel)
|
||||||
|
- workflow: Workflow management commands
|
||||||
|
- upload: Image upload commands
|
||||||
|
- history: History viewing and retrieval commands
|
||||||
|
- workflow_changes: Runtime workflow parameter management (prompt, seed, etc.)
|
||||||
|
- utility: Quality-of-life commands (ping, status, comfy-stats, comfy-queue, uptime)
|
||||||
|
- presets: Named workflow preset management
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
from config import BotConfig
|
||||||
|
|
||||||
|
from .generation import setup_generation_commands
|
||||||
|
from .input_images import setup_input_image_commands
|
||||||
|
from .server import setup_server_commands
|
||||||
|
from .workflow import setup_workflow_commands
|
||||||
|
from .history import setup_history_commands
|
||||||
|
from .workflow_changes import setup_workflow_changes_commands
|
||||||
|
from .utility import setup_utility_commands
|
||||||
|
from .presets import setup_preset_commands
|
||||||
|
from .help_command import CustomHelpCommand
|
||||||
|
|
||||||
|
|
||||||
|
def register_all_commands(bot, config: BotConfig):
|
||||||
|
"""
|
||||||
|
Register all bot commands.
|
||||||
|
|
||||||
|
This function should be called once during bot initialization to set up
|
||||||
|
all command handlers.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
bot : commands.Bot
|
||||||
|
The Discord bot instance.
|
||||||
|
config : BotConfig
|
||||||
|
The bot configuration object containing environment settings.
|
||||||
|
"""
|
||||||
|
setup_generation_commands(bot, config)
|
||||||
|
setup_input_image_commands(bot, config)
|
||||||
|
setup_server_commands(bot, config)
|
||||||
|
setup_workflow_commands(bot)
|
||||||
|
setup_history_commands(bot)
|
||||||
|
setup_workflow_changes_commands(bot)
|
||||||
|
setup_utility_commands(bot)
|
||||||
|
setup_preset_commands(bot)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"register_all_commands",
|
||||||
|
"setup_generation_commands",
|
||||||
|
"setup_input_image_commands",
|
||||||
|
"setup_workflow_commands",
|
||||||
|
"setup_history_commands",
|
||||||
|
"setup_workflow_changes_commands",
|
||||||
|
"setup_utility_commands",
|
||||||
|
"setup_preset_commands",
|
||||||
|
"CustomHelpCommand",
|
||||||
|
]
|
||||||
389
commands/generation.py
Normal file
389
commands/generation.py
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
"""
|
||||||
|
commands/generation.py
|
||||||
|
======================
|
||||||
|
|
||||||
|
Image and video generation commands for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
Jobs are submitted directly to ComfyUI (no internal SerialJobQueue).
|
||||||
|
ComfyUI's own queue handles ordering. Each Discord command waits for its
|
||||||
|
prompt_id to complete via WebSocket and then replies with the result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
try:
|
||||||
|
import aiohttp # type: ignore
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
aiohttp = None # type: ignore
|
||||||
|
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from config import ARG_PROMPT_KEY, ARG_NEG_PROMPT_KEY, ARG_QUEUE_KEY, MAX_IMAGES_PER_RESPONSE
|
||||||
|
from discord_utils import require_comfy_client, convert_image_bytes_to_discord_files
|
||||||
|
from media_uploader import flush_pending
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _safe_reply(
|
||||||
|
ctx: commands.Context,
|
||||||
|
*,
|
||||||
|
content: str | None = None,
|
||||||
|
files: list[discord.File] | None = None,
|
||||||
|
mention_author: bool = True,
|
||||||
|
delete_after: float | None = None,
|
||||||
|
tries: int = 4,
|
||||||
|
base_delay: float = 1.0,
|
||||||
|
):
|
||||||
|
"""Reply to Discord with retries for transient network/Discord errors."""
|
||||||
|
delay = base_delay
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
|
||||||
|
for attempt in range(1, tries + 1):
|
||||||
|
try:
|
||||||
|
return await ctx.reply(
|
||||||
|
content=content,
|
||||||
|
files=files or [],
|
||||||
|
mention_author=mention_author,
|
||||||
|
delete_after=delete_after,
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
last_exc = exc
|
||||||
|
transient = False
|
||||||
|
|
||||||
|
if isinstance(exc, asyncio.TimeoutError):
|
||||||
|
transient = True
|
||||||
|
elif isinstance(exc, OSError) and getattr(exc, "winerror", None) in {
|
||||||
|
64, 121, 1231, 10053, 10054,
|
||||||
|
}:
|
||||||
|
transient = True
|
||||||
|
|
||||||
|
if aiohttp is not None:
|
||||||
|
try:
|
||||||
|
if isinstance(exc, (aiohttp.ClientOSError, aiohttp.ClientConnectionError)):
|
||||||
|
transient = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(exc, discord.HTTPException):
|
||||||
|
status = getattr(exc, "status", None)
|
||||||
|
if status is None or status >= 500 or status == 429:
|
||||||
|
transient = True
|
||||||
|
|
||||||
|
if (not transient) or attempt == tries:
|
||||||
|
raise
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"Transient error sending Discord message (attempt %d/%d): %s: %s",
|
||||||
|
attempt, tries, type(exc).__name__, exc,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
delay *= 2
|
||||||
|
|
||||||
|
raise last_exc # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_line(bot) -> str:
|
||||||
|
"""Return a formatted seed line if a seed was tracked, else empty string."""
|
||||||
|
seed = getattr(bot.comfy, "last_seed", None)
|
||||||
|
return f"\nSeed: `{seed}`" if seed is not None else ""
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_generate(ctx: commands.Context, bot, prompt_text: str, negative_text: Optional[str]):
|
||||||
|
"""Execute a prompt-based generation and reply with results."""
|
||||||
|
images, prompt_id = await bot.comfy.generate_image(
|
||||||
|
prompt_text, negative_text,
|
||||||
|
source="discord", user_label=ctx.author.display_name,
|
||||||
|
)
|
||||||
|
if not images:
|
||||||
|
await ctx.reply(
|
||||||
|
"No images were generated. Please try again with a different prompt.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
files = convert_image_bytes_to_discord_files(
|
||||||
|
images, max_files=MAX_IMAGES_PER_RESPONSE, prefix="generated"
|
||||||
|
)
|
||||||
|
response_text = f"Generated {len(images)} image(s). Prompt ID: `{prompt_id}`{_seed_line(bot)}"
|
||||||
|
await _safe_reply(ctx, content=response_text, files=files, mention_author=True)
|
||||||
|
|
||||||
|
asyncio.create_task(flush_pending(
|
||||||
|
Path(bot.config.comfy_output_path),
|
||||||
|
bot.config.media_upload_user,
|
||||||
|
bot.config.media_upload_pass,
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_workflow(ctx: commands.Context, bot, config):
|
||||||
|
"""Execute a workflow-based generation and reply with results."""
|
||||||
|
logger.info("Executing workflow generation")
|
||||||
|
await ctx.reply("Executing workflow…", mention_author=False, delete_after=5.0)
|
||||||
|
images, videos, prompt_id = await bot.comfy.generate_image_with_workflow(
|
||||||
|
source="discord", user_label=ctx.author.display_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not images and not videos:
|
||||||
|
await ctx.reply(
|
||||||
|
"No images or videos were generated. Check the workflow and ComfyUI logs.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
seed_info = _seed_line(bot)
|
||||||
|
|
||||||
|
if videos:
|
||||||
|
output_path = config.comfy_output_path
|
||||||
|
video_file = None
|
||||||
|
for video_info in videos:
|
||||||
|
video_name = video_info.get("video_name")
|
||||||
|
video_subfolder = video_info.get("video_subfolder", "")
|
||||||
|
if video_name:
|
||||||
|
video_path = (
|
||||||
|
Path(output_path) / video_subfolder / video_name
|
||||||
|
if video_subfolder
|
||||||
|
else Path(output_path) / video_name
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
video_file = discord.File(
|
||||||
|
BytesIO(video_path.read_bytes()), filename=video_name
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to read video %s: %s", video_path, exc)
|
||||||
|
|
||||||
|
if video_file:
|
||||||
|
response_text = (
|
||||||
|
f"Generated {len(images)} image(s) and a video. "
|
||||||
|
f"Prompt ID: `{prompt_id}`{seed_info}"
|
||||||
|
)
|
||||||
|
await _safe_reply(ctx, content=response_text, files=[video_file], mention_author=True)
|
||||||
|
else:
|
||||||
|
await ctx.reply(
|
||||||
|
f"Generated output but failed to read video file. "
|
||||||
|
f"Prompt ID: `{prompt_id}`{seed_info}",
|
||||||
|
mention_author=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
files = convert_image_bytes_to_discord_files(
|
||||||
|
images, max_files=MAX_IMAGES_PER_RESPONSE, prefix="generated"
|
||||||
|
)
|
||||||
|
response_text = (
|
||||||
|
f"Generated {len(images)} image(s) using workflow. "
|
||||||
|
f"Prompt ID: `{prompt_id}`{seed_info}"
|
||||||
|
)
|
||||||
|
await _safe_reply(ctx, content=response_text, files=files, mention_author=True)
|
||||||
|
|
||||||
|
asyncio.create_task(flush_pending(
|
||||||
|
Path(config.comfy_output_path),
|
||||||
|
config.media_upload_user,
|
||||||
|
config.media_upload_pass,
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
def setup_generation_commands(bot, config):
|
||||||
|
"""Register generation commands with the bot."""
|
||||||
|
|
||||||
|
@bot.command(name="test", extras={"category": "Generation"})
|
||||||
|
async def test_command(ctx: commands.Context) -> None:
|
||||||
|
"""A simple test command to verify the bot is working."""
|
||||||
|
await ctx.reply(
|
||||||
|
"The bot is working! Use `ttr!generate` to create images.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@bot.command(name="generate", aliases=["gen"], extras={"category": "Generation"})
|
||||||
|
@require_comfy_client
|
||||||
|
async def generate(ctx: commands.Context, *, args: str = "") -> None:
|
||||||
|
"""
|
||||||
|
Generate images using ComfyUI.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
ttr!generate prompt:<your prompt> negative_prompt:<your negatives>
|
||||||
|
|
||||||
|
The ``prompt:`` keyword is required. ``negative_prompt:`` is optional.
|
||||||
|
"""
|
||||||
|
prompt_text: Optional[str] = None
|
||||||
|
negative_text: Optional[str] = None
|
||||||
|
|
||||||
|
if args:
|
||||||
|
if ARG_PROMPT_KEY in args:
|
||||||
|
parts = args.split(ARG_PROMPT_KEY, 1)[1]
|
||||||
|
if ARG_NEG_PROMPT_KEY in parts:
|
||||||
|
p, n = parts.split(ARG_NEG_PROMPT_KEY, 1)
|
||||||
|
prompt_text = p.strip()
|
||||||
|
negative_text = n.strip() or None
|
||||||
|
else:
|
||||||
|
prompt_text = parts.strip()
|
||||||
|
else:
|
||||||
|
prompt_text = args.strip()
|
||||||
|
|
||||||
|
if not prompt_text:
|
||||||
|
await ctx.reply(
|
||||||
|
f"Please specify a prompt: `{ARG_PROMPT_KEY}<your prompt>`.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
bot.last_gen = {"mode": "prompt", "prompt": prompt_text, "negative": negative_text}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Show queue position from ComfyUI before waiting
|
||||||
|
depth = await bot.comfy.get_queue_depth()
|
||||||
|
pos = depth + 1
|
||||||
|
ack = await ctx.reply(
|
||||||
|
f"Queued ✅ (ComfyUI position: ~{pos})",
|
||||||
|
mention_author=False,
|
||||||
|
delete_after=30.0,
|
||||||
|
)
|
||||||
|
await _run_generate(ctx, bot, prompt_text, negative_text)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Error generating image")
|
||||||
|
await ctx.reply(
|
||||||
|
f"An error occurred: {type(exc).__name__}: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@bot.command(
|
||||||
|
name="workflow-gen",
|
||||||
|
aliases=["workflow-generate", "wfg"],
|
||||||
|
extras={"category": "Generation"},
|
||||||
|
)
|
||||||
|
@require_comfy_client
|
||||||
|
async def generate_workflow_command(ctx: commands.Context, *, args: str = "") -> None:
|
||||||
|
"""
|
||||||
|
Generate using the currently loaded workflow template.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
ttr!workflow-gen
|
||||||
|
ttr!workflow-gen queue:<number>
|
||||||
|
"""
|
||||||
|
bot.last_gen = {"mode": "workflow", "prompt": None, "negative": None}
|
||||||
|
|
||||||
|
# Handle batch queue parameter
|
||||||
|
if ARG_QUEUE_KEY in args:
|
||||||
|
number_part = args.split(ARG_QUEUE_KEY, 1)[1].strip()
|
||||||
|
if number_part.isdigit():
|
||||||
|
queue_times = int(number_part)
|
||||||
|
if queue_times > 1:
|
||||||
|
await ctx.reply(
|
||||||
|
f"Queuing {queue_times} workflow runs…",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
for i in range(queue_times):
|
||||||
|
try:
|
||||||
|
depth = await bot.comfy.get_queue_depth()
|
||||||
|
pos = depth + 1
|
||||||
|
await ctx.reply(
|
||||||
|
f"Queued run {i+1}/{queue_times} ✅ (ComfyUI position: ~{pos})",
|
||||||
|
mention_author=False,
|
||||||
|
delete_after=30.0,
|
||||||
|
)
|
||||||
|
await _run_workflow(ctx, bot, config)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Error on workflow run %d", i + 1)
|
||||||
|
await ctx.reply(
|
||||||
|
f"Error on run {i+1}: {type(exc).__name__}: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
await ctx.reply(
|
||||||
|
"Please provide a number greater than 1 for queueing multiple runs.",
|
||||||
|
mention_author=False,
|
||||||
|
delete_after=30.0,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
await ctx.reply(
|
||||||
|
f"Invalid queue parameter. Use `{ARG_QUEUE_KEY}<number>`.",
|
||||||
|
mention_author=False,
|
||||||
|
delete_after=30.0,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
depth = await bot.comfy.get_queue_depth()
|
||||||
|
pos = depth + 1
|
||||||
|
await ctx.reply(
|
||||||
|
f"Queued ✅ (ComfyUI position: ~{pos})",
|
||||||
|
mention_author=False,
|
||||||
|
delete_after=30.0,
|
||||||
|
)
|
||||||
|
await _run_workflow(ctx, bot, config)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Error generating with workflow")
|
||||||
|
await ctx.reply(
|
||||||
|
f"An error occurred: {type(exc).__name__}: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@bot.command(name="rerun", aliases=["rr"], extras={"category": "Generation"})
|
||||||
|
@require_comfy_client
|
||||||
|
async def rerun_command(ctx: commands.Context) -> None:
|
||||||
|
"""
|
||||||
|
Re-run the last generation with the same parameters.
|
||||||
|
|
||||||
|
Re-submits the most recent ``ttr!generate`` or ``ttr!workflow-gen``
|
||||||
|
with the same mode and prompt. Current state overrides (seed,
|
||||||
|
input_image, etc.) are applied at execution time.
|
||||||
|
"""
|
||||||
|
last = getattr(bot, "last_gen", None)
|
||||||
|
if last is None:
|
||||||
|
await ctx.reply(
|
||||||
|
"No previous generation to rerun.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
depth = await bot.comfy.get_queue_depth()
|
||||||
|
pos = depth + 1
|
||||||
|
await ctx.reply(
|
||||||
|
f"Rerun queued ✅ (ComfyUI position: ~{pos})",
|
||||||
|
mention_author=False,
|
||||||
|
delete_after=30.0,
|
||||||
|
)
|
||||||
|
if last["mode"] == "prompt":
|
||||||
|
await _run_generate(ctx, bot, last["prompt"], last["negative"])
|
||||||
|
else:
|
||||||
|
await _run_workflow(ctx, bot, config)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Error queueing rerun")
|
||||||
|
await ctx.reply(
|
||||||
|
f"An error occurred: {type(exc).__name__}: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@bot.command(name="cancel", extras={"category": "Generation"})
|
||||||
|
@require_comfy_client
|
||||||
|
async def cancel_command(ctx: commands.Context) -> None:
|
||||||
|
"""
|
||||||
|
Clear all pending jobs from the ComfyUI queue.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
ttr!cancel
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ok = await bot.comfy.clear_queue()
|
||||||
|
if ok:
|
||||||
|
await ctx.reply("ComfyUI queue cleared.", mention_author=False)
|
||||||
|
else:
|
||||||
|
await ctx.reply(
|
||||||
|
"Failed to clear the ComfyUI queue (server may have returned an error).",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
await ctx.reply(f"Error: {exc}", mention_author=False)
|
||||||
134
commands/help_command.py
Normal file
134
commands/help_command.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""
|
||||||
|
commands/help_command.py
|
||||||
|
========================
|
||||||
|
|
||||||
|
Custom help command for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
Replaces discord.py's default help with a categorised listing that
|
||||||
|
automatically includes every registered command.
|
||||||
|
|
||||||
|
How it works
|
||||||
|
------------
|
||||||
|
Each ``@bot.command()`` decorator should carry an ``extras`` dict with a
|
||||||
|
``"category"`` key:
|
||||||
|
|
||||||
|
@bot.command(name="my-command", extras={"category": "Generation"})
|
||||||
|
async def my_command(ctx):
|
||||||
|
\"""One-line brief shown in the listing.
|
||||||
|
|
||||||
|
Longer description shown in ttr!help my-command.
|
||||||
|
\"""
|
||||||
|
|
||||||
|
The first line of the docstring becomes the brief shown in the main
|
||||||
|
listing. The full docstring is shown when the user asks for per-command
|
||||||
|
detail. Commands without a category appear under **Other**.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
ttr!help — list all commands grouped by category
|
||||||
|
ttr!help <command> — detailed help for a specific command
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import List, Mapping, Optional
|
||||||
|
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
|
||||||
|
# Order in which categories appear in the full help listing.
|
||||||
|
# Any category not listed here appears at the end, sorted alphabetically.
|
||||||
|
CATEGORY_ORDER = ["Generation", "Workflow", "Upload", "History", "Presets", "Utility"]
|
||||||
|
|
||||||
|
|
||||||
|
def _category_sort_key(name: str) -> tuple:
|
||||||
|
"""Return a sort key that respects CATEGORY_ORDER, then alphabetical."""
|
||||||
|
try:
|
||||||
|
return (CATEGORY_ORDER.index(name), name)
|
||||||
|
except ValueError:
|
||||||
|
return (len(CATEGORY_ORDER), name)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomHelpCommand(commands.HelpCommand):
|
||||||
|
"""
|
||||||
|
Categorised help command.
|
||||||
|
|
||||||
|
Groups commands by the ``"category"`` value in their ``extras`` dict.
|
||||||
|
Commands that omit this appear under **Other**.
|
||||||
|
|
||||||
|
Adding a new command to the help output requires no changes here —
|
||||||
|
just set ``extras={"category": "..."}`` on the decorator and write a
|
||||||
|
descriptive docstring.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Main listing — ttr!help
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def send_bot_help(
|
||||||
|
self,
|
||||||
|
mapping: Mapping[Optional[commands.Cog], List[commands.Command]],
|
||||||
|
) -> None:
|
||||||
|
"""Send the full command listing grouped by category."""
|
||||||
|
# Collect all visible commands across every cog / None bucket
|
||||||
|
all_commands: List[commands.Command] = []
|
||||||
|
for cmds in mapping.values():
|
||||||
|
filtered = await self.filter_commands(cmds)
|
||||||
|
all_commands.extend(filtered)
|
||||||
|
|
||||||
|
# Group by category
|
||||||
|
categories: dict[str, list[commands.Command]] = defaultdict(list)
|
||||||
|
for cmd in all_commands:
|
||||||
|
cat = cmd.extras.get("category", "Other")
|
||||||
|
categories[cat].append(cmd)
|
||||||
|
|
||||||
|
prefix = self.context.prefix
|
||||||
|
lines: list[str] = [f"**Commands** — prefix: `{prefix}`\n"]
|
||||||
|
|
||||||
|
for cat in sorted(categories.keys(), key=_category_sort_key):
|
||||||
|
cmds = sorted(categories[cat], key=lambda c: c.name)
|
||||||
|
lines.append(f"**{cat}**")
|
||||||
|
for cmd in cmds:
|
||||||
|
aliases = (
|
||||||
|
f" ({', '.join(cmd.aliases)})" if cmd.aliases else ""
|
||||||
|
)
|
||||||
|
brief = cmd.short_doc or "No description."
|
||||||
|
lines.append(f" `{cmd.name}`{aliases} — {brief}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
lines.append(
|
||||||
|
f"Use `{prefix}help <command>` for details on a specific command."
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.get_destination().send("\n".join(lines))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Per-command detail — ttr!help <command>
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def send_command_help(self, command: commands.Command) -> None:
|
||||||
|
"""Send detailed help for a single command."""
|
||||||
|
prefix = self.context.prefix
|
||||||
|
header = f"`{prefix}{command.name}`"
|
||||||
|
if command.aliases:
|
||||||
|
alias_list = ", ".join(f"`{a}`" for a in command.aliases)
|
||||||
|
header += f" (aliases: {alias_list})"
|
||||||
|
|
||||||
|
category = command.extras.get("category", "Other")
|
||||||
|
lines: list[str] = [header, f"Category: **{category}**", ""]
|
||||||
|
|
||||||
|
if command.help:
|
||||||
|
lines.append(command.help.strip())
|
||||||
|
else:
|
||||||
|
lines.append("No description available.")
|
||||||
|
|
||||||
|
await self.get_destination().send("\n".join(lines))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Error — unknown command name
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def send_error_message(self, error: str) -> None:
|
||||||
|
"""Forward the error text to the channel."""
|
||||||
|
await self.get_destination().send(error)
|
||||||
169
commands/history.py
Normal file
169
commands/history.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""
|
||||||
|
commands/history.py
|
||||||
|
===================
|
||||||
|
|
||||||
|
History management commands for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
This module contains commands for viewing and retrieving past generation
|
||||||
|
results from the bot's history.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from config import MAX_IMAGES_PER_RESPONSE
|
||||||
|
from discord_utils import require_comfy_client, truncate_text, convert_image_bytes_to_discord_files
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_history_commands(bot):
|
||||||
|
"""
|
||||||
|
Register history management commands with the bot.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
bot : commands.Bot
|
||||||
|
The Discord bot instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@bot.command(name="history", extras={"category": "History"})
|
||||||
|
@require_comfy_client
|
||||||
|
async def history_command(ctx: commands.Context) -> None:
|
||||||
|
"""
|
||||||
|
Show a list of recently generated prompts.
|
||||||
|
|
||||||
|
The bot keeps a rolling history of the last few generations. Each
|
||||||
|
entry lists the prompt id along with the positive and negative
|
||||||
|
prompt texts. You can retrieve the images from a previous
|
||||||
|
generation with the ``ttr!gethistory <prompt_id>`` command.
|
||||||
|
"""
|
||||||
|
hist = bot.comfy.get_history()
|
||||||
|
if not hist:
|
||||||
|
await ctx.reply(
|
||||||
|
"No history available yet. Generate something first!",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build a human readable list
|
||||||
|
lines = ["Here are the most recent generations (oldest first):"]
|
||||||
|
for entry in hist:
|
||||||
|
pid = entry.get("prompt_id", "unknown")
|
||||||
|
prompt = entry.get("prompt") or ""
|
||||||
|
neg = entry.get("negative_prompt") or ""
|
||||||
|
# Truncate long prompts for readability
|
||||||
|
lines.append(
|
||||||
|
f"• ID: {pid} | prompt: '{truncate_text(prompt, 60)}' | negative: '{truncate_text(neg, 60)}'"
|
||||||
|
)
|
||||||
|
await ctx.reply("\n".join(lines), mention_author=False)
|
||||||
|
|
||||||
|
@bot.command(name="get-history", aliases=["gethistory", "gh"], extras={"category": "History"})
|
||||||
|
@require_comfy_client
|
||||||
|
async def get_history_command(ctx: commands.Context, *, arg: str = "") -> None:
|
||||||
|
"""
|
||||||
|
Retrieve images from a previous generation, or search history by keyword.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ttr!gethistory <prompt_id_or_index>
|
||||||
|
ttr!gethistory search:<keyword>
|
||||||
|
|
||||||
|
Provide either the prompt id returned in the generation response
|
||||||
|
(shown in `ttr!history`) or the 1‑based index into the history
|
||||||
|
list. The bot will fetch the images associated with that
|
||||||
|
generation and resend them. If no images are found, you will be
|
||||||
|
notified.
|
||||||
|
|
||||||
|
Use ``search:<keyword>`` to filter history by prompt text, checkpoint
|
||||||
|
name, seed value, or any other override field.
|
||||||
|
"""
|
||||||
|
if not arg:
|
||||||
|
await ctx.reply(
|
||||||
|
"Please provide a prompt id, history index, or `search:<keyword>`. See `ttr!history` for a list.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle search:<keyword>
|
||||||
|
lower_arg = arg.lower()
|
||||||
|
if lower_arg.startswith("search:"):
|
||||||
|
keyword = arg[len("search:"):].strip()
|
||||||
|
if not keyword:
|
||||||
|
await ctx.reply("Please provide a keyword after `search:`.", mention_author=False)
|
||||||
|
return
|
||||||
|
from generation_db import search_history_for_user, get_history as db_get_history
|
||||||
|
# Use get_history for Discord since Discord bot doesn't have per-user context like the web UI
|
||||||
|
hist = db_get_history(limit=50)
|
||||||
|
matches = [
|
||||||
|
e for e in hist
|
||||||
|
if keyword.lower() in str(e.get("overrides", {})).lower()
|
||||||
|
]
|
||||||
|
if not matches:
|
||||||
|
await ctx.reply(f"No history entries matching `{keyword}`.", mention_author=False)
|
||||||
|
return
|
||||||
|
lines = [f"**History matching `{keyword}`** ({len(matches)} result(s))"]
|
||||||
|
for entry in matches[:10]:
|
||||||
|
pid = entry.get("prompt_id", "unknown")
|
||||||
|
overrides = entry.get("overrides") or {}
|
||||||
|
prompt = str(overrides.get("prompt") or "")
|
||||||
|
lines.append(
|
||||||
|
f"• `{pid[:12]}…` | {truncate_text(prompt, 60) if prompt else '(no prompt)'}"
|
||||||
|
)
|
||||||
|
if len(matches) > 10:
|
||||||
|
lines.append(f"_(showing first 10 of {len(matches)})_")
|
||||||
|
await ctx.reply("\n".join(lines), mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Determine whether arg refers to an index or an id
|
||||||
|
target_id: Optional[str] = None
|
||||||
|
hist = bot.comfy.get_history()
|
||||||
|
|
||||||
|
# If arg is a digit, interpret as 1‑based index
|
||||||
|
if arg.isdigit():
|
||||||
|
idx = int(arg) - 1
|
||||||
|
if idx < 0 or idx >= len(hist):
|
||||||
|
await ctx.reply(
|
||||||
|
f"Index out of range. There are {len(hist)} entries in history.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
target_id = hist[idx]["prompt_id"]
|
||||||
|
else:
|
||||||
|
# Otherwise treat as an explicit prompt id
|
||||||
|
target_id = arg.strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
images = await bot.comfy.fetch_history_images(target_id)
|
||||||
|
if not images:
|
||||||
|
await ctx.reply(
|
||||||
|
f"No images found for prompt id `{target_id}`.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
files = []
|
||||||
|
for idx, img_bytes in enumerate(images):
|
||||||
|
if idx >= MAX_IMAGES_PER_RESPONSE:
|
||||||
|
break
|
||||||
|
file_obj = BytesIO(img_bytes)
|
||||||
|
file_obj.seek(0)
|
||||||
|
files.append(discord.File(file_obj, filename=f"history_{target_id}_{idx+1}.png"))
|
||||||
|
|
||||||
|
await ctx.reply(
|
||||||
|
content=f"Here are the images for prompt id `{target_id}`:",
|
||||||
|
files=files,
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to fetch history for %s", target_id)
|
||||||
|
await ctx.reply(
|
||||||
|
f"An error occurred: {type(exc).__name__}: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
178
commands/input_images.py
Normal file
178
commands/input_images.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
"""
|
||||||
|
commands/input_images.py
|
||||||
|
========================
|
||||||
|
|
||||||
|
Channel-backed input image management.
|
||||||
|
|
||||||
|
Images uploaded to the designated `comfy-input` channel get a persistent
|
||||||
|
"✅ Set as input" button posted by the bot — one reply per attachment so
|
||||||
|
every image in a multi-image message is independently selectable.
|
||||||
|
|
||||||
|
Persistent views survive bot restarts: on_ready re-registers every view
|
||||||
|
stored in the SQLite database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from image_utils import compress_to_discord_limit
|
||||||
|
from input_image_db import (
|
||||||
|
activate_image_for_slot,
|
||||||
|
get_all_images,
|
||||||
|
upsert_image,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp"}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class PersistentSetInputView(discord.ui.View):
|
||||||
|
"""
|
||||||
|
A persistent view that survives bot restarts.
|
||||||
|
|
||||||
|
One instance is created per DB row (i.e. per attachment).
|
||||||
|
The button's custom_id encodes the row id so the callback can look
|
||||||
|
up the exact filename to download.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, bot, config, row_id: int):
|
||||||
|
super().__init__(timeout=None)
|
||||||
|
self._bot = bot
|
||||||
|
self._config = config
|
||||||
|
self._row_id = row_id
|
||||||
|
|
||||||
|
btn = discord.ui.Button(
|
||||||
|
label="✅ Set as input",
|
||||||
|
style=discord.ButtonStyle.success,
|
||||||
|
custom_id=f"set_input:{row_id}",
|
||||||
|
)
|
||||||
|
btn.callback = self._set_callback
|
||||||
|
self.add_item(btn)
|
||||||
|
|
||||||
|
async def _set_callback(self, interaction: discord.Interaction) -> None:
|
||||||
|
await interaction.response.defer(ephemeral=True)
|
||||||
|
try:
|
||||||
|
filename = activate_image_for_slot(
|
||||||
|
self._row_id, "input_image", self._config.comfy_input_path
|
||||||
|
)
|
||||||
|
self._bot.comfy.state_manager.set_override("input_image", filename)
|
||||||
|
await interaction.followup.send(
|
||||||
|
f"✅ Input image set to `{filename}`", ephemeral=True
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("set_input button failed for row %s", self._row_id)
|
||||||
|
await interaction.followup.send(f"❌ Error: {exc}", ephemeral=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def _register_attachment(bot, config, message: discord.Message, attachment: discord.Attachment) -> None:
|
||||||
|
"""Post a reply with the image preview, a Set-as-input button, and record it in the DB."""
|
||||||
|
logger.info("[_register_attachment] Start")
|
||||||
|
original_data = await attachment.read()
|
||||||
|
original_filename = attachment.filename
|
||||||
|
logger.info("[_register_attachment] Reading attachment")
|
||||||
|
|
||||||
|
# Compress only for the Discord re-send (8 MiB bot limit)
|
||||||
|
send_data, send_filename = compress_to_discord_limit(original_data, original_filename)
|
||||||
|
|
||||||
|
file = discord.File(io.BytesIO(send_data), filename=send_filename)
|
||||||
|
reply = await message.channel.send(f"`{original_filename}`", file=file)
|
||||||
|
|
||||||
|
# Store original quality bytes in DB
|
||||||
|
row_id = upsert_image(message.id, reply.id, message.channel.id, original_filename, image_data=original_data)
|
||||||
|
view = PersistentSetInputView(bot, config, row_id)
|
||||||
|
bot.add_view(view, message_id=reply.id)
|
||||||
|
logger.info("[_register_attachment] Done")
|
||||||
|
await reply.edit(view=view)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_input_image_commands(bot, config=None):
|
||||||
|
"""Register input image commands and the on_message listener."""
|
||||||
|
|
||||||
|
@bot.listen("on_message")
|
||||||
|
async def _on_input_channel_message(message: discord.Message) -> None:
|
||||||
|
"""Watch the comfy-input channel and attach a Set-as-input button to every image upload."""
|
||||||
|
if config is None:
|
||||||
|
logger.warning("[_on_input_channel_message] Config is none")
|
||||||
|
return
|
||||||
|
if message.channel.id != config.comfy_input_channel_id:
|
||||||
|
return
|
||||||
|
if message.author.bot:
|
||||||
|
return
|
||||||
|
|
||||||
|
image_attachments = [
|
||||||
|
a for a in message.attachments
|
||||||
|
if Path(a.filename).suffix.lower() in IMAGE_EXTENSIONS
|
||||||
|
]
|
||||||
|
if not image_attachments:
|
||||||
|
logger.info("[_on_input_channel_message] No image attachments")
|
||||||
|
return
|
||||||
|
|
||||||
|
for attachment in image_attachments:
|
||||||
|
await _register_attachment(bot, config, message, attachment)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await message.delete()
|
||||||
|
except discord.Forbidden:
|
||||||
|
logger.warning("Missing manage_messages permission to delete message %s", message.id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Could not delete message %s: %s", message.id, exc)
|
||||||
|
|
||||||
|
@bot.command(
|
||||||
|
name="sync-inputs",
|
||||||
|
aliases=["si"],
|
||||||
|
extras={"category": "Files"},
|
||||||
|
help="Scan the comfy-input channel and add 'Set as input' buttons to any untracked images.",
|
||||||
|
)
|
||||||
|
async def sync_inputs_command(ctx: commands.Context) -> None:
|
||||||
|
"""Backfill Set-as-input buttons for images uploaded while the bot was offline."""
|
||||||
|
if config is None:
|
||||||
|
await ctx.reply("Bot config is not available.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
channel = bot.get_channel(config.comfy_input_channel_id)
|
||||||
|
if channel is None:
|
||||||
|
try:
|
||||||
|
channel = await bot.fetch_channel(config.comfy_input_channel_id)
|
||||||
|
except Exception as exc:
|
||||||
|
await ctx.reply(f"❌ Could not access input channel: {exc}", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Track existing records as (message_id, filename) pairs
|
||||||
|
existing = {(row["original_message_id"], row["filename"]) for row in get_all_images()}
|
||||||
|
|
||||||
|
new_count = 0
|
||||||
|
async for message in channel.history(limit=None):
|
||||||
|
if message.author.bot:
|
||||||
|
continue
|
||||||
|
|
||||||
|
had_new = False
|
||||||
|
for attachment in message.attachments:
|
||||||
|
if Path(attachment.filename).suffix.lower() not in IMAGE_EXTENSIONS:
|
||||||
|
continue
|
||||||
|
if (message.id, attachment.filename) in existing:
|
||||||
|
continue
|
||||||
|
|
||||||
|
await _register_attachment(bot, config, message, attachment)
|
||||||
|
existing.add((message.id, attachment.filename))
|
||||||
|
new_count += 1
|
||||||
|
had_new = True
|
||||||
|
|
||||||
|
if had_new:
|
||||||
|
try:
|
||||||
|
await message.delete()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("sync-inputs: could not delete message %s: %s", message.id, exc)
|
||||||
|
|
||||||
|
already = len(get_all_images()) - new_count
|
||||||
|
await ctx.reply(
|
||||||
|
f"Synced {new_count} new image(s). {already} already known.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
370
commands/presets.py
Normal file
370
commands/presets.py
Normal file
@@ -0,0 +1,370 @@
|
|||||||
|
"""
|
||||||
|
commands/presets.py
|
||||||
|
===================
|
||||||
|
|
||||||
|
Named workflow preset commands for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
A preset is a saved snapshot of the current workflow template and runtime
|
||||||
|
state (prompt, negative_prompt, input_image, seed). Presets make it easy
|
||||||
|
to switch between different setups (e.g. "portrait", "landscape", "anime")
|
||||||
|
with a single command.
|
||||||
|
|
||||||
|
All sub-commands are accessed through the single ``ttr!preset`` command:
|
||||||
|
|
||||||
|
ttr!preset save <name> [description:<text>] — capture current workflow + state
|
||||||
|
ttr!preset load <name> — restore workflow + state
|
||||||
|
ttr!preset list — list all saved presets
|
||||||
|
ttr!preset view <name> — show preset details
|
||||||
|
ttr!preset delete <name> — permanently remove a preset
|
||||||
|
ttr!preset save-last <name> [description:<text>] — save last generation as preset
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from discord_utils import require_comfy_client
|
||||||
|
from preset_manager import PresetManager
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_name_and_description(args: str) -> tuple[str, str | None]:
|
||||||
|
"""
|
||||||
|
Split ``<name> [description:<text>]`` into (name, description).
|
||||||
|
|
||||||
|
The name is the first whitespace-delimited token. Everything after
|
||||||
|
``description:`` (case-insensitive) in the remaining text is the
|
||||||
|
description. Returns (name, None) if no description keyword is found.
|
||||||
|
"""
|
||||||
|
parts = args.strip().split(maxsplit=1)
|
||||||
|
name = parts[0] if parts else ""
|
||||||
|
description: str | None = None
|
||||||
|
if len(parts) > 1:
|
||||||
|
rest = parts[1]
|
||||||
|
lower = rest.lower()
|
||||||
|
idx = lower.find("description:")
|
||||||
|
if idx >= 0:
|
||||||
|
description = rest[idx + len("description:"):].strip() or None
|
||||||
|
return name, description
|
||||||
|
|
||||||
|
|
||||||
|
def setup_preset_commands(bot):
|
||||||
|
"""
|
||||||
|
Register preset commands with the bot.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
bot : commands.Bot
|
||||||
|
The Discord bot instance.
|
||||||
|
"""
|
||||||
|
preset_manager = PresetManager()
|
||||||
|
|
||||||
|
@bot.command(name="preset", extras={"category": "Presets"})
|
||||||
|
@require_comfy_client
|
||||||
|
async def preset_command(ctx: commands.Context, *, args: str = "") -> None:
|
||||||
|
"""
|
||||||
|
Save, load, list, view, or delete named workflow presets.
|
||||||
|
|
||||||
|
A preset captures the current workflow template and all runtime
|
||||||
|
state changes (prompt, negative_prompt, input_image, seed) under a
|
||||||
|
short name. Load it later to restore everything in one step.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ttr!preset save <name> [description:<text>] — save current workflow + state
|
||||||
|
ttr!preset load <name> — restore workflow + state
|
||||||
|
ttr!preset list — list all saved presets
|
||||||
|
ttr!preset view <name> — show preset details
|
||||||
|
ttr!preset delete <name> — permanently delete a preset
|
||||||
|
ttr!preset save-last <name> [description:<text>] — save last generation as preset
|
||||||
|
|
||||||
|
Names may only contain letters, digits, hyphens, and underscores.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
ttr!preset save portrait description:studio lighting style
|
||||||
|
ttr!preset load portrait
|
||||||
|
ttr!preset list
|
||||||
|
ttr!preset view portrait
|
||||||
|
ttr!preset delete portrait
|
||||||
|
ttr!preset save-last my-last
|
||||||
|
"""
|
||||||
|
parts = args.strip().split(maxsplit=1)
|
||||||
|
subcommand = parts[0].lower() if parts else ""
|
||||||
|
rest = parts[1].strip() if len(parts) > 1 else ""
|
||||||
|
|
||||||
|
if subcommand == "save":
|
||||||
|
name, description = _parse_name_and_description(rest)
|
||||||
|
await _preset_save(ctx, bot, preset_manager, name, description)
|
||||||
|
elif subcommand == "load":
|
||||||
|
await _preset_load(ctx, bot, preset_manager, rest.split()[0] if rest.split() else "")
|
||||||
|
elif subcommand == "list":
|
||||||
|
await _preset_list(ctx, preset_manager)
|
||||||
|
elif subcommand == "view":
|
||||||
|
await _preset_view(ctx, preset_manager, rest.split()[0] if rest.split() else "")
|
||||||
|
elif subcommand == "delete":
|
||||||
|
await _preset_delete(ctx, preset_manager, rest.split()[0] if rest.split() else "")
|
||||||
|
elif subcommand == "save-last":
|
||||||
|
name, description = _parse_name_and_description(rest)
|
||||||
|
await _preset_save_last(ctx, preset_manager, name, description)
|
||||||
|
else:
|
||||||
|
await ctx.reply(
|
||||||
|
"Usage: `ttr!preset <save|load|list|view|delete|save-last> [name]`\n"
|
||||||
|
"Run `ttr!help preset` for full details.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _preset_save(
|
||||||
|
ctx: commands.Context, bot, preset_manager: PresetManager, name: str,
|
||||||
|
description: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Handle ttr!preset save <name> [description:<text>]."""
|
||||||
|
if not name:
|
||||||
|
await ctx.reply(
|
||||||
|
"Please provide a name. Example: `ttr!preset save portrait`",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not PresetManager.is_valid_name(name):
|
||||||
|
await ctx.reply(
|
||||||
|
"Invalid name. Use only letters, digits, hyphens, and underscores (max 64 chars).",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
workflow_template = bot.comfy.get_workflow_template()
|
||||||
|
state = bot.comfy.get_workflow_current_changes()
|
||||||
|
preset_manager.save(name, workflow_template, state, description=description)
|
||||||
|
|
||||||
|
# Build a summary of what was saved
|
||||||
|
has_workflow = workflow_template is not None
|
||||||
|
state_parts = []
|
||||||
|
if state.get("prompt"):
|
||||||
|
state_parts.append("prompt")
|
||||||
|
if state.get("negative_prompt"):
|
||||||
|
state_parts.append("negative_prompt")
|
||||||
|
if state.get("input_image"):
|
||||||
|
state_parts.append("input_image")
|
||||||
|
if state.get("seed") is not None:
|
||||||
|
state_parts.append(f"seed={state['seed']}")
|
||||||
|
|
||||||
|
summary_parts = []
|
||||||
|
if has_workflow:
|
||||||
|
summary_parts.append("workflow template")
|
||||||
|
summary_parts.extend(state_parts)
|
||||||
|
summary = ", ".join(summary_parts) if summary_parts else "empty state"
|
||||||
|
|
||||||
|
desc_note = f"\n> {description}" if description else ""
|
||||||
|
await ctx.reply(
|
||||||
|
f"Preset **{name}** saved ({summary}).{desc_note}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to save preset '%s'", name)
|
||||||
|
await ctx.reply(
|
||||||
|
f"Failed to save preset: {type(exc).__name__}: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _preset_load(
|
||||||
|
ctx: commands.Context, bot, preset_manager: PresetManager, name: str
|
||||||
|
) -> None:
|
||||||
|
"""Handle ttr!preset load <name>."""
|
||||||
|
if not name:
|
||||||
|
await ctx.reply(
|
||||||
|
"Please provide a name. Example: `ttr!preset load portrait`",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
data = preset_manager.load(name)
|
||||||
|
if data is None:
|
||||||
|
presets = preset_manager.list_presets()
|
||||||
|
hint = f" Available: {', '.join(presets)}" if presets else " No presets saved yet."
|
||||||
|
await ctx.reply(
|
||||||
|
f"Preset **{name}** not found.{hint}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
restored: list[str] = []
|
||||||
|
|
||||||
|
# Restore workflow template if present
|
||||||
|
workflow = data.get("workflow")
|
||||||
|
if workflow is not None:
|
||||||
|
bot.comfy.set_workflow(workflow)
|
||||||
|
restored.append("workflow template")
|
||||||
|
|
||||||
|
# Restore state changes
|
||||||
|
state = data.get("state", {})
|
||||||
|
if state:
|
||||||
|
bot.comfy.set_workflow_current_changes(state)
|
||||||
|
if state.get("prompt"):
|
||||||
|
restored.append("prompt")
|
||||||
|
if state.get("negative_prompt"):
|
||||||
|
restored.append("negative_prompt")
|
||||||
|
if state.get("input_image"):
|
||||||
|
restored.append("input_image")
|
||||||
|
if state.get("seed") is not None:
|
||||||
|
restored.append(f"seed={state['seed']}")
|
||||||
|
|
||||||
|
summary = ", ".join(restored) if restored else "nothing (preset was empty)"
|
||||||
|
description = data.get("description")
|
||||||
|
desc_note = f"\n> {description}" if description else ""
|
||||||
|
await ctx.reply(
|
||||||
|
f"Preset **{name}** loaded ({summary}).{desc_note}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to load preset '%s'", name)
|
||||||
|
await ctx.reply(
|
||||||
|
f"Failed to load preset: {type(exc).__name__}: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _preset_view(
|
||||||
|
ctx: commands.Context, preset_manager: PresetManager, name: str
|
||||||
|
) -> None:
|
||||||
|
"""Handle ttr!preset view <name>."""
|
||||||
|
if not name:
|
||||||
|
await ctx.reply(
|
||||||
|
"Please provide a name. Example: `ttr!preset view portrait`",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
data = preset_manager.load(name)
|
||||||
|
if data is None:
|
||||||
|
await ctx.reply(f"Preset **{name}** not found.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
lines = [f"**Preset: {name}**"]
|
||||||
|
if data.get("description"):
|
||||||
|
lines.append(f"> {data['description']}")
|
||||||
|
if data.get("owner"):
|
||||||
|
lines.append(f"Owner: {data['owner']}")
|
||||||
|
|
||||||
|
state = data.get("state", {})
|
||||||
|
if state.get("prompt"):
|
||||||
|
# Truncate long prompts
|
||||||
|
p = str(state["prompt"])
|
||||||
|
if len(p) > 200:
|
||||||
|
p = p[:197] + "…"
|
||||||
|
lines.append(f"**Prompt:** {p}")
|
||||||
|
if state.get("negative_prompt"):
|
||||||
|
np = str(state["negative_prompt"])
|
||||||
|
if len(np) > 100:
|
||||||
|
np = np[:97] + "…"
|
||||||
|
lines.append(f"**Negative:** {np}")
|
||||||
|
if state.get("seed") is not None:
|
||||||
|
seed_note = " (random)" if state["seed"] == -1 else ""
|
||||||
|
lines.append(f"**Seed:** {state['seed']}{seed_note}")
|
||||||
|
|
||||||
|
other = {k: v for k, v in state.items() if k not in ("prompt", "negative_prompt", "seed", "input_image")}
|
||||||
|
if other:
|
||||||
|
other_str = ", ".join(f"{k}={v}" for k, v in other.items())
|
||||||
|
lines.append(f"**Other:** {other_str[:200]}")
|
||||||
|
|
||||||
|
if data.get("workflow") is not None:
|
||||||
|
lines.append("_(includes workflow template)_")
|
||||||
|
else:
|
||||||
|
lines.append("_(no workflow template — load separately)_")
|
||||||
|
|
||||||
|
await ctx.reply("\n".join(lines), mention_author=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def _preset_list(ctx: commands.Context, preset_manager: PresetManager) -> None:
|
||||||
|
"""Handle ttr!preset list."""
|
||||||
|
presets = preset_manager.list_preset_details()
|
||||||
|
if not presets:
|
||||||
|
await ctx.reply(
|
||||||
|
"No presets saved yet. Use `ttr!preset save <name>` to create one.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
lines = [f"**Saved presets** ({len(presets)})"]
|
||||||
|
for p in presets:
|
||||||
|
entry = f" • {p['name']}"
|
||||||
|
if p.get("description"):
|
||||||
|
entry += f" — {p['description']}"
|
||||||
|
lines.append(entry)
|
||||||
|
lines.append("\nUse `ttr!preset load <name>` to restore one.")
|
||||||
|
await ctx.reply("\n".join(lines), mention_author=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def _preset_delete(
|
||||||
|
ctx: commands.Context, preset_manager: PresetManager, name: str
|
||||||
|
) -> None:
|
||||||
|
"""Handle ttr!preset delete <name>."""
|
||||||
|
if not name:
|
||||||
|
await ctx.reply(
|
||||||
|
"Please provide a name. Example: `ttr!preset delete portrait`",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
deleted = preset_manager.delete(name)
|
||||||
|
if deleted:
|
||||||
|
await ctx.reply(f"Preset **{name}** deleted.", mention_author=False)
|
||||||
|
else:
|
||||||
|
await ctx.reply(
|
||||||
|
f"Preset **{name}** not found.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _preset_save_last(
|
||||||
|
ctx: commands.Context, preset_manager: PresetManager, name: str,
|
||||||
|
description: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Handle ttr!preset save-last <name> [description:<text>]."""
|
||||||
|
if not name:
|
||||||
|
await ctx.reply(
|
||||||
|
"Please provide a name. Example: `ttr!preset save-last my-last`",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not PresetManager.is_valid_name(name):
|
||||||
|
await ctx.reply(
|
||||||
|
"Invalid name. Use only letters, digits, hyphens, and underscores (max 64 chars).",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
from generation_db import get_history as db_get_history
|
||||||
|
history = db_get_history(limit=1)
|
||||||
|
if not history:
|
||||||
|
await ctx.reply(
|
||||||
|
"No generation history found. Generate something first!",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
last = history[0]
|
||||||
|
overrides = last.get("overrides") or {}
|
||||||
|
try:
|
||||||
|
preset_manager.save(name, None, overrides, description=description)
|
||||||
|
desc_note = f"\n> {description}" if description else ""
|
||||||
|
await ctx.reply(
|
||||||
|
f"Preset **{name}** saved from last generation.{desc_note}\n"
|
||||||
|
"Note: workflow template not included — load it separately before generating.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
await ctx.reply(str(exc), mention_author=False)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to save preset '%s' from history", name)
|
||||||
|
await ctx.reply(
|
||||||
|
f"Failed to save preset: {type(exc).__name__}: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
484
commands/server.py
Normal file
484
commands/server.py
Normal file
@@ -0,0 +1,484 @@
|
|||||||
|
"""
|
||||||
|
commands/server.py
|
||||||
|
==================
|
||||||
|
|
||||||
|
ComfyUI server lifecycle management via NSSM Windows service.
|
||||||
|
|
||||||
|
On bot startup, `autostart_comfy()` runs as a background task:
|
||||||
|
1. If the service does not exist, it is installed automatically.
|
||||||
|
2. If the service exists but ComfyUI is not responding, it is started.
|
||||||
|
|
||||||
|
NSSM handles:
|
||||||
|
- Background process management (no console window)
|
||||||
|
- Stdout / stderr capture to rotating log files
|
||||||
|
- Complete isolation from the bot's own NSSM service
|
||||||
|
|
||||||
|
Commands:
|
||||||
|
ttr!server start — start the service
|
||||||
|
ttr!server stop — stop the service
|
||||||
|
ttr!server restart — restart the service
|
||||||
|
ttr!server status — NSSM service state + HTTP reachability
|
||||||
|
ttr!server install — (re)install / reconfigure the NSSM service
|
||||||
|
ttr!server uninstall — remove the service from Windows
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
- nssm.exe in PATH
|
||||||
|
- The bot service account must have permission to manage Windows services
|
||||||
|
(Local System or a user with SeServiceLogonRight works)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_POLL_INTERVAL = 5 # seconds between HTTP up-checks
|
||||||
|
_MAX_ATTEMPTS = 24 # 24 × 5s = 120s max wait
|
||||||
|
|
||||||
|
# Public — imported by status_monitor for emoji rendering
|
||||||
|
STATUS_EMOJI: dict[str, str] = {
|
||||||
|
"SERVICE_RUNNING": "🟢",
|
||||||
|
"SERVICE_STOPPED": "🔴",
|
||||||
|
"SERVICE_PAUSED": "🟡",
|
||||||
|
"SERVICE_START_PENDING": "⏳",
|
||||||
|
"SERVICE_STOP_PENDING": "⏳",
|
||||||
|
"SERVICE_PAUSE_PENDING": "⏳",
|
||||||
|
"SERVICE_CONTINUE_PENDING": "⏳",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Low-level subprocess helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _nssm(*args: str) -> tuple[int, str]:
|
||||||
|
"""Run `nssm <args>` and return (returncode, stdout)."""
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
"nssm", *args,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.STDOUT,
|
||||||
|
)
|
||||||
|
stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=30)
|
||||||
|
return proc.returncode, stdout.decode(errors="replace").strip()
|
||||||
|
except FileNotFoundError:
|
||||||
|
return -1, "nssm not found — is it installed and in PATH?"
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return -1, "nssm command timed out."
|
||||||
|
except Exception as exc:
|
||||||
|
return -1, str(exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_service_pid(service_name: str) -> int:
|
||||||
|
"""Return the PID of the process backing *service_name*, or 0 if unavailable."""
|
||||||
|
rc, out = await _nssm("getpid", service_name)
|
||||||
|
if rc != 0:
|
||||||
|
return 0
|
||||||
|
try:
|
||||||
|
return int(out.strip())
|
||||||
|
except ValueError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
async def _kill_service_process(service_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Forcefully kill the process backing *service_name*.
|
||||||
|
|
||||||
|
NSSM does not have a `kill` subcommand. Instead we retrieve the PID
|
||||||
|
via `nssm getpid` and then use `taskkill /F /PID`. Safe to call when
|
||||||
|
the service is already stopped (no-op if PID is 0).
|
||||||
|
"""
|
||||||
|
pid = await _get_service_pid(service_name)
|
||||||
|
if not pid:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
"taskkill", "/F", "/PID", str(pid),
|
||||||
|
stdout=asyncio.subprocess.DEVNULL,
|
||||||
|
stderr=asyncio.subprocess.DEVNULL,
|
||||||
|
)
|
||||||
|
await asyncio.wait_for(proc.communicate(), timeout=10)
|
||||||
|
logger.debug("taskkill /F /PID %d sent for service '%s'", pid, service_name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("taskkill failed for PID %d (%s): %s", pid, service_name, exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _is_comfy_up(server_address: str, timeout: float = 3.0) -> bool:
|
||||||
|
"""Return True if the ComfyUI HTTP endpoint is responding."""
|
||||||
|
url = f"http://{server_address}/system_stats"
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url, timeout=aiohttp.ClientTimeout(total=timeout)) as resp:
|
||||||
|
return resp.status == 200
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def _service_exists(service_name: str) -> bool:
|
||||||
|
"""Return True if the Windows service is installed (running or stopped)."""
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
"sc", "query", service_name,
|
||||||
|
stdout=asyncio.subprocess.DEVNULL,
|
||||||
|
stderr=asyncio.subprocess.DEVNULL,
|
||||||
|
)
|
||||||
|
await proc.communicate()
|
||||||
|
return proc.returncode == 0
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API — used by status_monitor and other modules
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def get_service_state(service_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Return the NSSM service state string for *service_name*.
|
||||||
|
|
||||||
|
Returns one of the SERVICE_* keys in STATUS_EMOJI on success, or
|
||||||
|
"error" / "timeout" / "unknown" on failure. Intended for use by
|
||||||
|
the status dashboard — callers should not raise on these sentinel values.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
rc, out = await asyncio.wait_for(_nssm("status", service_name), timeout=5.0)
|
||||||
|
if rc == -1:
|
||||||
|
return "error"
|
||||||
|
return out.strip() or "unknown"
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return "timeout"
|
||||||
|
except Exception:
|
||||||
|
return "error"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Service installation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _install_service(config) -> tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Install the ComfyUI NSSM service with log capture and rotation.
|
||||||
|
|
||||||
|
We install directly via python.exe (not the .bat file) to avoid the
|
||||||
|
"Terminate batch job (Y/N)?" prompt that can cause NSSM to hang on STOP.
|
||||||
|
|
||||||
|
Safe to call even if the service already exists — it will be removed first.
|
||||||
|
Returns (success, message).
|
||||||
|
"""
|
||||||
|
name = config.comfy_service_name
|
||||||
|
start_bat = Path(config.comfy_start_bat)
|
||||||
|
log_dir = Path(config.comfy_log_dir)
|
||||||
|
log_file = str(log_dir / "comfyui.log")
|
||||||
|
max_bytes = str(config.comfy_log_max_mb * 1024 * 1024)
|
||||||
|
|
||||||
|
# Derive portable paths from the .bat location (ComfyUI_windows_portable root):
|
||||||
|
# <root>/run_nvidia_gpu.bat
|
||||||
|
# <root>/python_embeded/python.exe
|
||||||
|
# <root>/ComfyUI/main.py
|
||||||
|
portable_root = start_bat.parent
|
||||||
|
python_exe = portable_root / "python_embeded" / "python.exe"
|
||||||
|
main_py = portable_root / "ComfyUI" / "main.py"
|
||||||
|
|
||||||
|
if not start_bat.exists():
|
||||||
|
return False, f"Start bat not found (used to derive paths): `{start_bat}`"
|
||||||
|
if not python_exe.exists():
|
||||||
|
return False, f"Portable python not found: `{python_exe}`"
|
||||||
|
if not main_py.exists():
|
||||||
|
return False, f"ComfyUI main.py not found: `{main_py}`"
|
||||||
|
|
||||||
|
log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Optional extra args from config (accepts string or list/tuple)
|
||||||
|
extra_args: list[str] = []
|
||||||
|
extra = getattr(config, "comfy_extra_args", None)
|
||||||
|
try:
|
||||||
|
if isinstance(extra, (list, tuple)):
|
||||||
|
extra_args = [str(x) for x in extra if str(x).strip()]
|
||||||
|
elif isinstance(extra, str) and extra.strip():
|
||||||
|
import shlex
|
||||||
|
extra_args = shlex.split(extra)
|
||||||
|
except Exception:
|
||||||
|
extra_args = [] # ignore parse errors rather than aborting install
|
||||||
|
|
||||||
|
# Remove any existing service cleanly before reinstalling
|
||||||
|
if await _service_exists(name):
|
||||||
|
await _nssm("stop", name)
|
||||||
|
await _kill_service_process(name) # force-kill if stuck in STOP_PENDING
|
||||||
|
rc, out = await _nssm("remove", name, "confirm")
|
||||||
|
if rc != 0:
|
||||||
|
return False, f"Could not remove existing service: {out}"
|
||||||
|
|
||||||
|
# nssm install <name> <python.exe> -s <main.py> --windows-standalone-build [extra]
|
||||||
|
steps: list[tuple[str, ...]] = [
|
||||||
|
("install", name, str(python_exe), "-s", str(main_py), "--windows-standalone-build", *extra_args),
|
||||||
|
("set", name, "AppDirectory", str(portable_root)),
|
||||||
|
("set", name, "DisplayName", "ComfyUI Server"),
|
||||||
|
("set", name, "AppStdout", log_file),
|
||||||
|
("set", name, "AppStderr", log_file),
|
||||||
|
("set", name, "AppRotateFiles", "1"),
|
||||||
|
("set", name, "AppRotateBytes", max_bytes),
|
||||||
|
("set", name, "AppRotateOnline", "1"),
|
||||||
|
("set", name, "Start", "SERVICE_DEMAND_START"),
|
||||||
|
# Stop behavior — prevent NSSM from hanging indefinitely
|
||||||
|
("set", name, "AppKillProcessTree", "1"),
|
||||||
|
("set", name, "AppStopMethodConsole", "1500"),
|
||||||
|
("set", name, "AppStopMethodWindow", "1500"),
|
||||||
|
("set", name, "AppStopMethodThreads", "1500"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for step in steps:
|
||||||
|
rc, out = await _nssm(*step)
|
||||||
|
if rc != 0:
|
||||||
|
return False, f"`nssm {' '.join(step[:3])}` failed: {out}"
|
||||||
|
|
||||||
|
return True, f"Service `{name}` installed. Log: `{log_file}`"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Autostart (called from bot.py on_ready)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def autostart_comfy(config) -> None:
|
||||||
|
"""
|
||||||
|
Ensure ComfyUI is running when the bot starts.
|
||||||
|
|
||||||
|
1. Install the NSSM service if it is missing.
|
||||||
|
2. Start the service if ComfyUI is not already responding.
|
||||||
|
|
||||||
|
Does nothing if config.comfy_autostart is False.
|
||||||
|
"""
|
||||||
|
if not getattr(config, "comfy_autostart", True):
|
||||||
|
return
|
||||||
|
|
||||||
|
if not await _service_exists(config.comfy_service_name):
|
||||||
|
logger.info("NSSM service '%s' not found — installing", config.comfy_service_name)
|
||||||
|
ok, msg = await _install_service(config)
|
||||||
|
if not ok:
|
||||||
|
logger.error("Failed to install ComfyUI service: %s", msg)
|
||||||
|
return
|
||||||
|
logger.info("ComfyUI service installed: %s", msg)
|
||||||
|
|
||||||
|
if await _is_comfy_up(config.comfy_server):
|
||||||
|
logger.info("ComfyUI already running at %s", config.comfy_server)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Starting NSSM service '%s'", config.comfy_service_name)
|
||||||
|
rc, out = await _nssm("start", config.comfy_service_name)
|
||||||
|
if rc != 0:
|
||||||
|
logger.warning("nssm start returned %d: %s", rc, out)
|
||||||
|
return
|
||||||
|
|
||||||
|
for attempt in range(_MAX_ATTEMPTS):
|
||||||
|
await asyncio.sleep(_POLL_INTERVAL)
|
||||||
|
if await _is_comfy_up(config.comfy_server):
|
||||||
|
logger.info("ComfyUI is up after ~%ds", (attempt + 1) * _POLL_INTERVAL)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"ComfyUI did not respond within %ds after service start",
|
||||||
|
_MAX_ATTEMPTS * _POLL_INTERVAL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Discord commands
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def setup_server_commands(bot, config=None):
|
||||||
|
"""Register ComfyUI server management commands."""
|
||||||
|
|
||||||
|
def _no_config(ctx):
|
||||||
|
"""Reply and return True when config is missing (guards every subcommand)."""
|
||||||
|
return config is None
|
||||||
|
|
||||||
|
@bot.group(name="server", invoke_without_command=True, extras={"category": "Server"})
|
||||||
|
async def server_group(ctx: commands.Context) -> None:
|
||||||
|
"""ComfyUI server management. Subcommands: start, stop, restart, status, install, uninstall."""
|
||||||
|
await ctx.send_help(ctx.command)
|
||||||
|
|
||||||
|
@server_group.command(name="start")
|
||||||
|
async def server_start(ctx: commands.Context) -> None:
|
||||||
|
"""Start the ComfyUI service."""
|
||||||
|
if config is None:
|
||||||
|
await ctx.reply("Bot config not available.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
if await _is_comfy_up(config.comfy_server):
|
||||||
|
await ctx.reply("✅ ComfyUI is already running.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
msg = await ctx.reply(
|
||||||
|
f"⏳ Starting service `{config.comfy_service_name}`…", mention_author=False
|
||||||
|
)
|
||||||
|
rc, out = await _nssm("start", config.comfy_service_name)
|
||||||
|
if rc != 0:
|
||||||
|
await msg.edit(content=f"❌ `{out}`")
|
||||||
|
return
|
||||||
|
|
||||||
|
await msg.edit(content="⏳ Waiting for ComfyUI to respond…")
|
||||||
|
for attempt in range(_MAX_ATTEMPTS):
|
||||||
|
await asyncio.sleep(_POLL_INTERVAL)
|
||||||
|
if await _is_comfy_up(config.comfy_server):
|
||||||
|
await msg.edit(
|
||||||
|
content=f"✅ ComfyUI is up! (took ~{(attempt + 1) * _POLL_INTERVAL}s)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
await msg.edit(content="⚠️ Service started but ComfyUI did not respond within 120 seconds.")
|
||||||
|
|
||||||
|
@server_group.command(name="stop")
|
||||||
|
async def server_stop(ctx: commands.Context) -> None:
|
||||||
|
"""Stop the ComfyUI service (force-kills if graceful stop fails)."""
|
||||||
|
if config is None:
|
||||||
|
await ctx.reply("Bot config not available.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
msg = await ctx.reply(
|
||||||
|
f"⏳ Stopping service `{config.comfy_service_name}`…", mention_author=False
|
||||||
|
)
|
||||||
|
rc, out = await _nssm("stop", config.comfy_service_name)
|
||||||
|
if rc == 0:
|
||||||
|
await msg.edit(content="✅ ComfyUI service stopped.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Graceful stop failed (timed out or error) — force-kill the process.
|
||||||
|
await msg.edit(content="⏳ Graceful stop failed — force-killing process…")
|
||||||
|
await _kill_service_process(config.comfy_service_name)
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
state = await get_service_state(config.comfy_service_name)
|
||||||
|
if state == "SERVICE_STOPPED":
|
||||||
|
await msg.edit(content="✅ ComfyUI service force-killed and stopped.")
|
||||||
|
else:
|
||||||
|
await msg.edit(
|
||||||
|
content=f"⚠️ Force-kill sent but service state is `{state}`. "
|
||||||
|
f"Use `ttr!server kill` to try again."
|
||||||
|
)
|
||||||
|
|
||||||
|
@server_group.command(name="kill")
|
||||||
|
async def server_kill(ctx: commands.Context) -> None:
|
||||||
|
"""Force-kill the ComfyUI process when it is stuck in STOPPING/STOP_PENDING."""
|
||||||
|
if config is None:
|
||||||
|
await ctx.reply("Bot config not available.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
msg = await ctx.reply(
|
||||||
|
f"⏳ Force-killing `{config.comfy_service_name}` process…", mention_author=False
|
||||||
|
)
|
||||||
|
await _kill_service_process(config.comfy_service_name)
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
state = await get_service_state(config.comfy_service_name)
|
||||||
|
emoji = STATUS_EMOJI.get(state, "⚪")
|
||||||
|
await msg.edit(
|
||||||
|
content=f"💀 taskkill sent. Service state is now {emoji} `{state}`."
|
||||||
|
)
|
||||||
|
|
||||||
|
@server_group.command(name="restart")
|
||||||
|
async def server_restart(ctx: commands.Context) -> None:
|
||||||
|
"""Restart the ComfyUI service (force-kills if graceful stop fails)."""
|
||||||
|
if config is None:
|
||||||
|
await ctx.reply("Bot config not available.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
msg = await ctx.reply(
|
||||||
|
f"⏳ Stopping `{config.comfy_service_name}` for restart…", mention_author=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 1: graceful stop.
|
||||||
|
rc, out = await _nssm("stop", config.comfy_service_name)
|
||||||
|
if rc != 0:
|
||||||
|
# Stop timed out or failed — force-kill so we can start fresh.
|
||||||
|
await msg.edit(content="⏳ Graceful stop failed — force-killing process…")
|
||||||
|
await _kill_service_process(config.comfy_service_name)
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# Step 2: verify stopped before starting.
|
||||||
|
state = await get_service_state(config.comfy_service_name)
|
||||||
|
if state not in ("SERVICE_STOPPED", "error", "unknown", "timeout"):
|
||||||
|
# Still not fully stopped — try one more force-kill.
|
||||||
|
await _kill_service_process(config.comfy_service_name)
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# Step 3: start.
|
||||||
|
await msg.edit(content=f"⏳ Starting `{config.comfy_service_name}`…")
|
||||||
|
rc, out = await _nssm("start", config.comfy_service_name)
|
||||||
|
if rc != 0:
|
||||||
|
await msg.edit(content=f"❌ Start failed: `{out}`")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 4: wait for HTTP.
|
||||||
|
await msg.edit(content="⏳ Waiting for ComfyUI to come back up…")
|
||||||
|
for attempt in range(_MAX_ATTEMPTS):
|
||||||
|
await asyncio.sleep(_POLL_INTERVAL)
|
||||||
|
if await _is_comfy_up(config.comfy_server):
|
||||||
|
await msg.edit(
|
||||||
|
content=f"✅ ComfyUI is back up! (took ~{(attempt + 1) * _POLL_INTERVAL}s)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
await msg.edit(content="⚠️ Service started but ComfyUI did not respond within 120 seconds.")
|
||||||
|
|
||||||
|
@server_group.command(name="status")
|
||||||
|
async def server_status(ctx: commands.Context) -> None:
|
||||||
|
"""Show NSSM service state and HTTP reachability."""
|
||||||
|
if config is None:
|
||||||
|
await ctx.reply("Bot config not available.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
state, http_up = await asyncio.gather(
|
||||||
|
get_service_state(config.comfy_service_name),
|
||||||
|
_is_comfy_up(config.comfy_server),
|
||||||
|
)
|
||||||
|
|
||||||
|
emoji = STATUS_EMOJI.get(state, "⚪")
|
||||||
|
svc_line = f"{emoji} `{state}`"
|
||||||
|
http_line = (
|
||||||
|
f"🟢 Responding at `{config.comfy_server}`"
|
||||||
|
if http_up else
|
||||||
|
f"🔴 Not responding at `{config.comfy_server}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
await ctx.reply(
|
||||||
|
f"**ComfyUI Server Status**\n"
|
||||||
|
f"Service `{config.comfy_service_name}`: {svc_line}\n"
|
||||||
|
f"HTTP: {http_line}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@server_group.command(name="install")
|
||||||
|
async def server_install(ctx: commands.Context) -> None:
|
||||||
|
"""(Re)install the ComfyUI NSSM service with current config settings."""
|
||||||
|
if config is None:
|
||||||
|
await ctx.reply("Bot config not available.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
msg = await ctx.reply(
|
||||||
|
f"⏳ Installing service `{config.comfy_service_name}`…", mention_author=False
|
||||||
|
)
|
||||||
|
ok, detail = await _install_service(config)
|
||||||
|
await msg.edit(content=f"{'✅' if ok else '❌'} {detail}")
|
||||||
|
|
||||||
|
@server_group.command(name="uninstall")
|
||||||
|
async def server_uninstall(ctx: commands.Context) -> None:
|
||||||
|
"""Stop and remove the ComfyUI NSSM service from Windows."""
|
||||||
|
if config is None:
|
||||||
|
await ctx.reply("Bot config not available.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
msg = await ctx.reply(
|
||||||
|
f"⏳ Removing service `{config.comfy_service_name}`…", mention_author=False
|
||||||
|
)
|
||||||
|
await _nssm("stop", config.comfy_service_name)
|
||||||
|
await _kill_service_process(config.comfy_service_name)
|
||||||
|
rc, out = await _nssm("remove", config.comfy_service_name, "confirm")
|
||||||
|
if rc == 0:
|
||||||
|
await msg.edit(content=f"✅ Service `{config.comfy_service_name}` removed.")
|
||||||
|
else:
|
||||||
|
await msg.edit(content=f"❌ `{out}`")
|
||||||
268
commands/utility.py
Normal file
268
commands/utility.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
"""
|
||||||
|
commands/utility.py
|
||||||
|
===================
|
||||||
|
|
||||||
|
Quality-of-life utility commands for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
Commands provided:
|
||||||
|
- ping: Show bot latency (Discord WebSocket round-trip).
|
||||||
|
- status: Full overview of bot health, ComfyUI connectivity,
|
||||||
|
workflow state, and queue.
|
||||||
|
- queue-status: Quick view of pending job count and worker state.
|
||||||
|
- uptime: How long the bot has been running since it connected.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_uptime(start_time: datetime) -> str:
|
||||||
|
"""Return a human-readable uptime string from a UTC start datetime."""
|
||||||
|
delta = datetime.now(timezone.utc) - start_time
|
||||||
|
total_seconds = int(delta.total_seconds())
|
||||||
|
days, remainder = divmod(total_seconds, 86400)
|
||||||
|
hours, remainder = divmod(remainder, 3600)
|
||||||
|
minutes, seconds = divmod(remainder, 60)
|
||||||
|
if days:
|
||||||
|
return f"{days}d {hours}h {minutes}m {seconds}s"
|
||||||
|
if hours:
|
||||||
|
return f"{hours}h {minutes}m {seconds}s"
|
||||||
|
return f"{minutes}m {seconds}s"
|
||||||
|
|
||||||
|
|
||||||
|
def setup_utility_commands(bot):
|
||||||
|
"""
|
||||||
|
Register quality-of-life utility commands with the bot.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
bot : commands.Bot
|
||||||
|
The Discord bot instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@bot.command(name="ping", extras={"category": "Utility"})
|
||||||
|
async def ping_command(ctx: commands.Context) -> None:
|
||||||
|
"""
|
||||||
|
Show the bot's current Discord WebSocket latency.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ttr!ping
|
||||||
|
"""
|
||||||
|
latency_ms = round(bot.latency * 1000)
|
||||||
|
await ctx.reply(f"Pong! Latency: **{latency_ms} ms**", mention_author=False)
|
||||||
|
|
||||||
|
@bot.command(name="status", extras={"category": "Utility"})
|
||||||
|
async def status_command(ctx: commands.Context) -> None:
|
||||||
|
"""
|
||||||
|
Show a full health overview of the bot and ComfyUI.
|
||||||
|
|
||||||
|
Displays:
|
||||||
|
- Bot latency and uptime
|
||||||
|
- ComfyUI server address and reachability
|
||||||
|
- Whether a workflow template is loaded
|
||||||
|
- Current workflow changes (prompt / negative_prompt / input_image)
|
||||||
|
- Job queue size and worker state
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ttr!status
|
||||||
|
"""
|
||||||
|
latency_ms = round(bot.latency * 1000)
|
||||||
|
|
||||||
|
# Uptime
|
||||||
|
if hasattr(bot, "start_time") and bot.start_time:
|
||||||
|
uptime_str = _format_uptime(bot.start_time)
|
||||||
|
else:
|
||||||
|
uptime_str = "N/A"
|
||||||
|
|
||||||
|
# ComfyUI info
|
||||||
|
comfy_ok = hasattr(bot, "comfy") and bot.comfy is not None
|
||||||
|
comfy_server = bot.comfy.server_address if comfy_ok else "not configured"
|
||||||
|
comfy_reachable = await bot.comfy.check_connection() if comfy_ok else False
|
||||||
|
workflow_loaded = comfy_ok and bot.comfy.get_workflow_template() is not None
|
||||||
|
|
||||||
|
# ComfyUI queue
|
||||||
|
comfy_pending = 0
|
||||||
|
comfy_running = 0
|
||||||
|
if comfy_ok:
|
||||||
|
q = await bot.comfy.get_comfy_queue()
|
||||||
|
if q:
|
||||||
|
comfy_pending = len(q.get("queue_pending", []))
|
||||||
|
comfy_running = len(q.get("queue_running", []))
|
||||||
|
|
||||||
|
# Workflow state summary
|
||||||
|
changes_parts: list[str] = []
|
||||||
|
if comfy_ok:
|
||||||
|
overrides = bot.comfy.state_manager.get_overrides()
|
||||||
|
if overrides.get("prompt"):
|
||||||
|
changes_parts.append("prompt")
|
||||||
|
if overrides.get("negative_prompt"):
|
||||||
|
changes_parts.append("negative_prompt")
|
||||||
|
if overrides.get("input_image"):
|
||||||
|
changes_parts.append(f"input_image: {overrides['input_image']}")
|
||||||
|
if overrides.get("seed") is not None:
|
||||||
|
changes_parts.append(f"seed={overrides['seed']}")
|
||||||
|
changes_summary = ", ".join(changes_parts) if changes_parts else "none"
|
||||||
|
|
||||||
|
conn_status = (
|
||||||
|
"reachable" if comfy_reachable
|
||||||
|
else ("unreachable" if comfy_ok else "not configured")
|
||||||
|
)
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
"**Bot**",
|
||||||
|
f" Latency : {latency_ms} ms",
|
||||||
|
f" Uptime : {uptime_str}",
|
||||||
|
"",
|
||||||
|
f"**ComfyUI** — `{comfy_server}`",
|
||||||
|
f" Connection : {conn_status}",
|
||||||
|
f" Queue : {comfy_running} running, {comfy_pending} pending",
|
||||||
|
f" Workflow : {'loaded' if workflow_loaded else 'not loaded'}",
|
||||||
|
f" Changes set : {changes_summary}",
|
||||||
|
]
|
||||||
|
await ctx.reply("\n".join(lines), mention_author=False)
|
||||||
|
|
||||||
|
@bot.command(name="queue-status", aliases=["qs", "qstatus"], extras={"category": "Utility"})
|
||||||
|
async def queue_status_command(ctx: commands.Context) -> None:
|
||||||
|
"""
|
||||||
|
Show the current ComfyUI queue depth.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ttr!queue-status
|
||||||
|
ttr!qs
|
||||||
|
"""
|
||||||
|
if not hasattr(bot, "comfy") or not bot.comfy:
|
||||||
|
await ctx.reply("ComfyUI client is not configured.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
q = await bot.comfy.get_comfy_queue()
|
||||||
|
if q is None:
|
||||||
|
await ctx.reply("Could not reach ComfyUI server.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
pending = len(q.get("queue_pending", []))
|
||||||
|
running = len(q.get("queue_running", []))
|
||||||
|
await ctx.reply(
|
||||||
|
f"ComfyUI queue: **{running}** running, **{pending}** pending.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@bot.command(name="uptime", extras={"category": "Utility"})
|
||||||
|
async def uptime_command(ctx: commands.Context) -> None:
|
||||||
|
"""
|
||||||
|
Show how long the bot has been running since it last connected.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ttr!uptime
|
||||||
|
"""
|
||||||
|
if not hasattr(bot, "start_time") or not bot.start_time:
|
||||||
|
await ctx.reply("Uptime information is not available.", mention_author=False)
|
||||||
|
return
|
||||||
|
uptime_str = _format_uptime(bot.start_time)
|
||||||
|
await ctx.reply(f"Uptime: **{uptime_str}**", mention_author=False)
|
||||||
|
|
||||||
|
@bot.command(name="comfy-stats", aliases=["cstats"], extras={"category": "Utility"})
|
||||||
|
async def comfy_stats_command(ctx: commands.Context) -> None:
|
||||||
|
"""
|
||||||
|
Show GPU and system stats from the ComfyUI server.
|
||||||
|
|
||||||
|
Displays OS, Python version, and per-device VRAM usage reported
|
||||||
|
by the ComfyUI ``/system_stats`` endpoint.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ttr!comfy-stats
|
||||||
|
ttr!cstats
|
||||||
|
"""
|
||||||
|
if not hasattr(bot, "comfy") or not bot.comfy:
|
||||||
|
await ctx.reply("ComfyUI client is not configured.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
stats = await bot.comfy.get_system_stats()
|
||||||
|
if stats is None:
|
||||||
|
await ctx.reply(
|
||||||
|
"Could not reach the ComfyUI server to fetch stats.", mention_author=False
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
system = stats.get("system", {})
|
||||||
|
devices = stats.get("devices", [])
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
f"**ComfyUI System Stats** — `{bot.comfy.server_address}`",
|
||||||
|
f" OS : {system.get('os', 'N/A')}",
|
||||||
|
f" Python : {system.get('python_version', 'N/A')}",
|
||||||
|
]
|
||||||
|
|
||||||
|
if devices:
|
||||||
|
lines.append("")
|
||||||
|
lines.append("**Devices**")
|
||||||
|
for dev in devices:
|
||||||
|
name = dev.get("name", "unknown")
|
||||||
|
vram_total = dev.get("vram_total", 0)
|
||||||
|
vram_free = dev.get("vram_free", 0)
|
||||||
|
vram_used = vram_total - vram_free
|
||||||
|
|
||||||
|
def _mb(b: int) -> str:
|
||||||
|
return f"{b / 1024 / 1024:.0f} MB"
|
||||||
|
|
||||||
|
lines.append(
|
||||||
|
f" {name} — {_mb(vram_used)} / {_mb(vram_total)} VRAM used"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lines.append(" No device info available.")
|
||||||
|
|
||||||
|
await ctx.reply("\n".join(lines), mention_author=False)
|
||||||
|
|
||||||
|
@bot.command(name="comfy-queue", aliases=["cqueue", "cq"], extras={"category": "Utility"})
|
||||||
|
async def comfy_queue_command(ctx: commands.Context) -> None:
|
||||||
|
"""
|
||||||
|
Show the ComfyUI server's internal queue state.
|
||||||
|
|
||||||
|
Displays jobs currently running and pending on the ComfyUI server
|
||||||
|
itself (separate from the Discord bot's own job queue).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ttr!comfy-queue
|
||||||
|
ttr!cq
|
||||||
|
"""
|
||||||
|
if not hasattr(bot, "comfy") or not bot.comfy:
|
||||||
|
await ctx.reply("ComfyUI client is not configured.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
queue_data = await bot.comfy.get_comfy_queue()
|
||||||
|
if queue_data is None:
|
||||||
|
await ctx.reply(
|
||||||
|
"Could not reach the ComfyUI server to fetch queue info.", mention_author=False
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
running = queue_data.get("queue_running", [])
|
||||||
|
pending = queue_data.get("queue_pending", [])
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
f"**ComfyUI Server Queue** — `{bot.comfy.server_address}`",
|
||||||
|
f" Running : {len(running)} job(s)",
|
||||||
|
f" Pending : {len(pending)} job(s)",
|
||||||
|
]
|
||||||
|
|
||||||
|
if running:
|
||||||
|
lines.append("")
|
||||||
|
lines.append("**Currently running**")
|
||||||
|
for entry in running[:5]: # cap at 5 to avoid huge messages
|
||||||
|
prompt_id = entry[1] if len(entry) > 1 else "unknown"
|
||||||
|
lines.append(f" `{prompt_id}`")
|
||||||
|
|
||||||
|
if pending:
|
||||||
|
lines.append("")
|
||||||
|
lines.append(f"**Pending** (showing up to 5 of {len(pending)})")
|
||||||
|
for entry in pending[:5]:
|
||||||
|
prompt_id = entry[1] if len(entry) > 1 else "unknown"
|
||||||
|
lines.append(f" `{prompt_id}`")
|
||||||
|
|
||||||
|
await ctx.reply("\n".join(lines), mention_author=False)
|
||||||
100
commands/workflow.py
Normal file
100
commands/workflow.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""
|
||||||
|
commands/workflow.py
|
||||||
|
====================
|
||||||
|
|
||||||
|
Workflow management commands for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
This module contains commands for loading and managing workflow templates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict
|
||||||
|
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from discord_utils import require_comfy_client
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_workflow_commands(bot):
|
||||||
|
"""
|
||||||
|
Register workflow management commands with the bot.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
bot : commands.Bot
|
||||||
|
The Discord bot instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@bot.command(name="workflow-load", aliases=["workflowload", "wfl"], extras={"category": "Workflow"})
|
||||||
|
@require_comfy_client
|
||||||
|
async def load_workflow_command(ctx: commands.Context, *, path: Optional[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
Load a ComfyUI workflow from a JSON file.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ttr!workflow-load path/to/workflow.json
|
||||||
|
|
||||||
|
You can also attach a JSON file to the command message instead of
|
||||||
|
providing a path. The loaded workflow will replace the current
|
||||||
|
workflow template used by the bot. After loading a workflow you
|
||||||
|
can generate images with your prompts while reusing the loaded
|
||||||
|
graph structure.
|
||||||
|
"""
|
||||||
|
workflow_data: Optional[Dict] = None
|
||||||
|
|
||||||
|
# Check for attached JSON file first
|
||||||
|
for attachment in ctx.message.attachments:
|
||||||
|
if attachment.filename.lower().endswith(".json"):
|
||||||
|
raw = await attachment.read()
|
||||||
|
try:
|
||||||
|
text = raw.decode("utf-8")
|
||||||
|
except UnicodeDecodeError as exc:
|
||||||
|
await ctx.reply(
|
||||||
|
f"`{attachment.filename}` is not valid UTF-8: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
workflow_data = json.loads(text)
|
||||||
|
break
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
await ctx.reply(
|
||||||
|
f"Failed to parse `{attachment.filename}` as JSON: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Otherwise try to load from provided path
|
||||||
|
if workflow_data is None and path:
|
||||||
|
try:
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
workflow_data = json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
await ctx.reply(f"File not found: `{path}`", mention_author=False)
|
||||||
|
return
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
await ctx.reply(f"Invalid JSON in `{path}`: {exc}", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
if workflow_data is None:
|
||||||
|
await ctx.reply(
|
||||||
|
"Please provide a JSON workflow file either as an attachment or a path.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Set the workflow on the client
|
||||||
|
try:
|
||||||
|
bot.comfy.set_workflow(workflow_data)
|
||||||
|
await ctx.reply("Workflow loaded successfully.", mention_author=False)
|
||||||
|
except Exception as exc:
|
||||||
|
await ctx.reply(
|
||||||
|
f"Failed to set workflow: {type(exc).__name__}: {exc}",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
252
commands/workflow_changes.py
Normal file
252
commands/workflow_changes.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""
|
||||||
|
commands/workflow_changes.py
|
||||||
|
============================
|
||||||
|
|
||||||
|
Workflow override management commands for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
Works with any NodeInput.key discovered by WorkflowInspector — not just
|
||||||
|
the four original hard-coded keys. Backward-compat aliases are preserved:
|
||||||
|
``type:prompt``, ``type:negative_prompt``, ``type:input_image``,
|
||||||
|
``type:seed``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from config import ARG_TYPE_KEY
|
||||||
|
from discord_utils import require_comfy_client
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_workflow_changes_commands(bot):
|
||||||
|
"""Register workflow changes commands with the bot."""
|
||||||
|
|
||||||
|
@bot.command(
|
||||||
|
name="get-current-workflow-changes",
|
||||||
|
aliases=["getworkflowchanges", "gcwc"],
|
||||||
|
extras={"category": "Workflow"},
|
||||||
|
)
|
||||||
|
@require_comfy_client
|
||||||
|
async def get_current_workflow_changes_command(
|
||||||
|
ctx: commands.Context, *, args: str = ""
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Show current workflow override values.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
ttr!get-current-workflow-changes type:all
|
||||||
|
ttr!get-current-workflow-changes type:prompt
|
||||||
|
ttr!get-current-workflow-changes type:<any_override_key>
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
overrides = bot.comfy.state_manager.get_overrides()
|
||||||
|
|
||||||
|
if ARG_TYPE_KEY not in args:
|
||||||
|
await ctx.reply(
|
||||||
|
f"Use `{ARG_TYPE_KEY}all` to see all overrides, or "
|
||||||
|
f"`{ARG_TYPE_KEY}<key>` for a specific key.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
param = args.split(ARG_TYPE_KEY, 1)[1].strip().lower()
|
||||||
|
|
||||||
|
if param == "all":
|
||||||
|
if not overrides:
|
||||||
|
await ctx.reply("No overrides set.", mention_author=False)
|
||||||
|
return
|
||||||
|
lines = [f"**{k}**: `{v}`" for k, v in sorted(overrides.items())]
|
||||||
|
await ctx.reply(
|
||||||
|
"Current overrides:\n" + "\n".join(lines),
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Support multi-word value with the key as prefix
|
||||||
|
key = param.split()[0] if " " in param else param
|
||||||
|
val = overrides.get(key)
|
||||||
|
if val is None:
|
||||||
|
await ctx.reply(
|
||||||
|
f"Override `{key}` is not set.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await ctx.reply(
|
||||||
|
f"**{key}**: `{val}`",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to get workflow overrides")
|
||||||
|
await ctx.reply(f"An error occurred: {type(exc).__name__}: {exc}", mention_author=False)
|
||||||
|
|
||||||
|
@bot.command(
|
||||||
|
name="set-current-workflow-changes",
|
||||||
|
aliases=["setworkflowchanges", "scwc"],
|
||||||
|
extras={"category": "Workflow"},
|
||||||
|
)
|
||||||
|
@require_comfy_client
|
||||||
|
async def set_current_workflow_changes_command(
|
||||||
|
ctx: commands.Context, *, args: str = ""
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Set a workflow override value.
|
||||||
|
|
||||||
|
Supports any NodeInput.key discovered by WorkflowInspector as well
|
||||||
|
as the legacy fixed keys.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
ttr!set-current-workflow-changes type:<key> <value>
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
ttr!scwc type:prompt A beautiful landscape
|
||||||
|
ttr!scwc type:negative_prompt blurry
|
||||||
|
ttr!scwc type:input_image my_image.png
|
||||||
|
ttr!scwc type:steps 30
|
||||||
|
ttr!scwc type:cfg 7.5
|
||||||
|
ttr!scwc type:seed 42
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not args or ARG_TYPE_KEY not in args:
|
||||||
|
await ctx.reply(
|
||||||
|
f"Usage: `ttr!set-current-workflow-changes {ARG_TYPE_KEY}<key> <value>`",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
rest = args.split(ARG_TYPE_KEY, 1)[1]
|
||||||
|
# Key is the first word; value is everything after the first space
|
||||||
|
parts = rest.split(None, 1)
|
||||||
|
if len(parts) < 2:
|
||||||
|
await ctx.reply(
|
||||||
|
"Please provide both a key and a value. "
|
||||||
|
f"Example: `ttr!scwc {ARG_TYPE_KEY}prompt A cat`",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
key = parts[0].strip().lower()
|
||||||
|
raw_value: str = parts[1].strip()
|
||||||
|
|
||||||
|
if not key:
|
||||||
|
await ctx.reply("Key cannot be empty.", mention_author=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Type-coerce well-known numeric keys
|
||||||
|
_int_keys = {"steps", "width", "height"}
|
||||||
|
_float_keys = {"cfg", "denoise"}
|
||||||
|
_seed_keys = {"seed", "noise_seed"}
|
||||||
|
|
||||||
|
value: object = raw_value
|
||||||
|
try:
|
||||||
|
if key in _int_keys:
|
||||||
|
value = int(raw_value)
|
||||||
|
elif key in _float_keys:
|
||||||
|
value = float(raw_value)
|
||||||
|
elif key in _seed_keys:
|
||||||
|
value = int(raw_value)
|
||||||
|
except ValueError:
|
||||||
|
await ctx.reply(
|
||||||
|
f"Invalid value for `{key}`: expected a number, got `{raw_value}`.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
bot.comfy.state_manager.set_override(key, value)
|
||||||
|
await ctx.reply(
|
||||||
|
f"Override **{key}** set to `{value}`.",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to set workflow override")
|
||||||
|
await ctx.reply(f"An error occurred: {type(exc).__name__}: {exc}", mention_author=False)
|
||||||
|
|
||||||
|
@bot.command(
|
||||||
|
name="clear-workflow-change",
|
||||||
|
aliases=["clearworkflowchange", "cwc"],
|
||||||
|
extras={"category": "Workflow"},
|
||||||
|
)
|
||||||
|
@require_comfy_client
|
||||||
|
async def clear_workflow_change_command(
|
||||||
|
ctx: commands.Context, *, args: str = ""
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Remove a single override key.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
ttr!clear-workflow-change type:<key>
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if ARG_TYPE_KEY not in args:
|
||||||
|
await ctx.reply(
|
||||||
|
f"Usage: `ttr!clear-workflow-change {ARG_TYPE_KEY}<key>`",
|
||||||
|
mention_author=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
key = args.split(ARG_TYPE_KEY, 1)[1].strip().lower()
|
||||||
|
bot.comfy.state_manager.delete_override(key)
|
||||||
|
await ctx.reply(f"Override **{key}** cleared.", mention_author=False)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to clear override")
|
||||||
|
await ctx.reply(f"An error occurred: {type(exc).__name__}: {exc}", mention_author=False)
|
||||||
|
|
||||||
|
@bot.command(
|
||||||
|
name="set-seed",
|
||||||
|
aliases=["setseed"],
|
||||||
|
extras={"category": "Workflow"},
|
||||||
|
)
|
||||||
|
@require_comfy_client
|
||||||
|
async def set_seed_command(ctx: commands.Context, *, args: str = "") -> None:
|
||||||
|
"""
|
||||||
|
Pin a specific seed for deterministic generation.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
ttr!set-seed 42
|
||||||
|
"""
|
||||||
|
seed_str = args.strip()
|
||||||
|
if not seed_str:
|
||||||
|
await ctx.reply("Usage: `ttr!set-seed <number>`", mention_author=False)
|
||||||
|
return
|
||||||
|
if not seed_str.isdigit():
|
||||||
|
await ctx.reply("Seed must be a non-negative integer.", mention_author=False)
|
||||||
|
return
|
||||||
|
seed_val = int(seed_str)
|
||||||
|
max_seed = 2 ** 32 - 1
|
||||||
|
if seed_val > max_seed:
|
||||||
|
await ctx.reply(f"Seed must be between 0 and {max_seed}.", mention_author=False)
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
bot.comfy.state_manager.set_seed(seed_val)
|
||||||
|
await ctx.reply(f"Seed pinned to `{seed_val}`.", mention_author=False)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to set seed")
|
||||||
|
await ctx.reply(f"An error occurred: {type(exc).__name__}: {exc}", mention_author=False)
|
||||||
|
|
||||||
|
@bot.command(
|
||||||
|
name="clear-seed",
|
||||||
|
aliases=["clearseed"],
|
||||||
|
extras={"category": "Workflow"},
|
||||||
|
)
|
||||||
|
@require_comfy_client
|
||||||
|
async def clear_seed_command(ctx: commands.Context) -> None:
|
||||||
|
"""
|
||||||
|
Clear the pinned seed and return to random generation.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
ttr!clear-seed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
bot.comfy.state_manager.clear_seed()
|
||||||
|
await ctx.reply("Seed cleared; generation will now use random seeds.", mention_author=False)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to clear seed")
|
||||||
|
await ctx.reply(f"An error occurred: {type(exc).__name__}: {exc}", mention_author=False)
|
||||||
283
config.py
Normal file
283
config.py
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
"""
|
||||||
|
config.py
|
||||||
|
=========
|
||||||
|
|
||||||
|
Configuration module for the Discord ComfyUI bot.
|
||||||
|
This module centralizes all constants, magic strings, and environment
|
||||||
|
variable loading to make configuration management easier and more maintainable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ========================================
|
||||||
|
# Command and Argument Constants
|
||||||
|
# ========================================
|
||||||
|
|
||||||
|
COMMAND_PREFIX = os.getenv("BOT_PREFIX", "ttr!")
|
||||||
|
"""The command prefix used for Discord bot commands."""
|
||||||
|
|
||||||
|
ARG_PROMPT_KEY = "prompt:"
|
||||||
|
"""The keyword marker for prompt arguments in commands."""
|
||||||
|
|
||||||
|
ARG_NEG_PROMPT_KEY = "negative_prompt:"
|
||||||
|
"""The keyword marker for negative prompt arguments in commands."""
|
||||||
|
|
||||||
|
ARG_TYPE_KEY = "type:"
|
||||||
|
"""The keyword marker for type arguments in commands."""
|
||||||
|
|
||||||
|
ARG_QUEUE_KEY = "queue:"
|
||||||
|
"""The keyword marker for queue count arguments in commands."""
|
||||||
|
|
||||||
|
|
||||||
|
# ========================================
|
||||||
|
# Discord and Message Constants
|
||||||
|
# ========================================
|
||||||
|
|
||||||
|
MAX_IMAGES_PER_RESPONSE = 4
|
||||||
|
"""Maximum number of images to include in a single Discord response."""
|
||||||
|
|
||||||
|
DEFAULT_UPLOAD_TYPE = "input"
|
||||||
|
"""Default folder type for ComfyUI image uploads."""
|
||||||
|
|
||||||
|
MESSAGE_AUTO_DELETE_TIMEOUT = 60.0
|
||||||
|
"""Default timeout in seconds for auto-deleting temporary messages."""
|
||||||
|
|
||||||
|
|
||||||
|
# ========================================
|
||||||
|
# Error Messages
|
||||||
|
# ========================================
|
||||||
|
|
||||||
|
COMFY_NOT_CONFIGURED_MSG = "ComfyUI client is not configured. Please set environment variables."
|
||||||
|
"""Error message displayed when ComfyUI client is not properly configured."""
|
||||||
|
|
||||||
|
|
||||||
|
# ========================================
|
||||||
|
# Default Configuration Values
|
||||||
|
# ========================================
|
||||||
|
|
||||||
|
DEFAULT_COMFY_HISTORY_LIMIT = 10
|
||||||
|
"""Default number of generation history entries to keep."""
|
||||||
|
|
||||||
|
# Resolve paths relative to this file's location so both the bot project and
|
||||||
|
# the portable ComfyUI folder only need to share the same parent directory.
|
||||||
|
# Layout assumed:
|
||||||
|
# <parent>/
|
||||||
|
# ComfyUI_windows_portable/ComfyUI/output ← default output
|
||||||
|
# ComfyUI_windows_portable/ComfyUI/input ← default input
|
||||||
|
# the-third-rev/ ← this project
|
||||||
|
_COMFY_PORTABLE_ROOT = Path(__file__).resolve().parent.parent / "ComfyUI_windows_portable" / "ComfyUI"
|
||||||
|
DEFAULT_COMFY_OUTPUT_PATH = str(_COMFY_PORTABLE_ROOT / "output")
|
||||||
|
DEFAULT_COMFY_INPUT_PATH = str(_COMFY_PORTABLE_ROOT / "input")
|
||||||
|
|
||||||
|
|
||||||
|
# ========================================
|
||||||
|
# Configuration Class
|
||||||
|
# ========================================
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BotConfig:
|
||||||
|
"""
|
||||||
|
Configuration container for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
This dataclass holds all configuration values loaded from environment
|
||||||
|
variables. Use the `from_env()` class method to create an instance
|
||||||
|
with values loaded from the environment.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
discord_bot_token : str
|
||||||
|
Discord bot authentication token (required).
|
||||||
|
comfy_server : str
|
||||||
|
ComfyUI server address in format "hostname:port" (required).
|
||||||
|
comfy_output_path : str
|
||||||
|
Path to ComfyUI output directory for reading generated files.
|
||||||
|
comfy_history_limit : int
|
||||||
|
Number of generation history entries to keep in memory.
|
||||||
|
workflow_file : Optional[str]
|
||||||
|
Path to a workflow JSON file to load at startup (optional).
|
||||||
|
"""
|
||||||
|
|
||||||
|
discord_bot_token: str
|
||||||
|
comfy_server: str
|
||||||
|
comfy_output_path: str
|
||||||
|
comfy_input_path: str
|
||||||
|
comfy_history_limit: int
|
||||||
|
comfy_input_channel_id: int = 1475791295665405962
|
||||||
|
comfy_service_name: str = "ComfyUI"
|
||||||
|
comfy_start_bat: str = ""
|
||||||
|
comfy_log_dir: str = ""
|
||||||
|
comfy_log_max_mb: int = 10
|
||||||
|
comfy_autostart: bool = True
|
||||||
|
workflow_file: Optional[str] = None
|
||||||
|
log_channel_id: Optional[int] = None
|
||||||
|
zip_password: Optional[str] = None
|
||||||
|
media_upload_user: Optional[str] = None
|
||||||
|
media_upload_pass: Optional[str] = None
|
||||||
|
# Web UI fields
|
||||||
|
web_enabled: bool = True
|
||||||
|
web_host: str = "0.0.0.0"
|
||||||
|
web_port: int = 8080
|
||||||
|
web_secret_key: str = ""
|
||||||
|
web_token_file: str = "invite_tokens.json"
|
||||||
|
web_jwt_expire_hours: int = 8
|
||||||
|
web_secure_cookie: bool = True
|
||||||
|
admin_password: Optional[str] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls) -> BotConfig:
|
||||||
|
"""
|
||||||
|
Create a BotConfig instance by loading values from environment variables.
|
||||||
|
|
||||||
|
Environment Variables
|
||||||
|
---------------------
|
||||||
|
DISCORD_BOT_TOKEN : str (required)
|
||||||
|
Discord bot authentication token.
|
||||||
|
COMFY_SERVER : str (required)
|
||||||
|
ComfyUI server address (e.g., "localhost:8188" or "example.com:8188").
|
||||||
|
COMFY_OUTPUT_PATH : str (optional)
|
||||||
|
Path to ComfyUI output directory. Defaults to DEFAULT_COMFY_OUTPUT_PATH
|
||||||
|
if not specified.
|
||||||
|
COMFY_HISTORY_LIMIT : int (optional)
|
||||||
|
Number of generation history entries to keep. Defaults to
|
||||||
|
DEFAULT_COMFY_HISTORY_LIMIT if not specified or invalid.
|
||||||
|
WORKFLOW_FILE : str (optional)
|
||||||
|
Path to a workflow JSON file to load at startup.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
BotConfig
|
||||||
|
A configured BotConfig instance.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
RuntimeError
|
||||||
|
If required environment variables (DISCORD_BOT_TOKEN or COMFY_SERVER)
|
||||||
|
are not set.
|
||||||
|
"""
|
||||||
|
# Load required variables
|
||||||
|
discord_token = os.getenv("DISCORD_BOT_TOKEN")
|
||||||
|
if not discord_token:
|
||||||
|
raise RuntimeError(
|
||||||
|
"DISCORD_BOT_TOKEN environment variable is required. "
|
||||||
|
"Please set it in your .env file or environment."
|
||||||
|
)
|
||||||
|
|
||||||
|
comfy_server = os.getenv("COMFY_SERVER")
|
||||||
|
if not comfy_server:
|
||||||
|
raise RuntimeError(
|
||||||
|
"COMFY_SERVER environment variable is required. "
|
||||||
|
"Please set it in your .env file or environment."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load optional variables with defaults
|
||||||
|
comfy_output_path = os.getenv("COMFY_OUTPUT_PATH", DEFAULT_COMFY_OUTPUT_PATH)
|
||||||
|
comfy_input_path = os.getenv("COMFY_INPUT_PATH", DEFAULT_COMFY_INPUT_PATH)
|
||||||
|
|
||||||
|
# Parse history limit with fallback to default
|
||||||
|
try:
|
||||||
|
comfy_history_limit = int(os.getenv("COMFY_HISTORY_LIMIT", str(DEFAULT_COMFY_HISTORY_LIMIT)))
|
||||||
|
except ValueError:
|
||||||
|
comfy_history_limit = DEFAULT_COMFY_HISTORY_LIMIT
|
||||||
|
|
||||||
|
workflow_file = os.getenv("WORKFLOW_FILE")
|
||||||
|
|
||||||
|
log_channel_id_str = os.getenv("LOG_CHANNEL_ID", "1475408462740721809")
|
||||||
|
try:
|
||||||
|
log_channel_id = int(log_channel_id_str) if log_channel_id_str else None
|
||||||
|
except ValueError:
|
||||||
|
log_channel_id = None
|
||||||
|
|
||||||
|
zip_password = os.getenv("ZIP_PASSWORD", "0Revel512796@")
|
||||||
|
|
||||||
|
media_upload_user = os.getenv("MEDIA_UPLOAD_USER") or None
|
||||||
|
media_upload_pass = os.getenv("MEDIA_UPLOAD_PASS") or None
|
||||||
|
|
||||||
|
try:
|
||||||
|
comfy_input_channel_id = int(os.getenv("COMFY_INPUT_CHANNEL_ID", "1475791295665405962"))
|
||||||
|
except ValueError:
|
||||||
|
comfy_input_channel_id = 1475791295665405962
|
||||||
|
|
||||||
|
comfy_service_name = os.getenv("COMFY_SERVICE_NAME", "ComfyUI")
|
||||||
|
|
||||||
|
default_bat = str(_COMFY_PORTABLE_ROOT.parent / "run_nvidia_gpu.bat")
|
||||||
|
comfy_start_bat = os.getenv("COMFY_START_BAT", default_bat)
|
||||||
|
|
||||||
|
default_log_dir = str(_COMFY_PORTABLE_ROOT.parent / "logs")
|
||||||
|
comfy_log_dir = os.getenv("COMFY_LOG_DIR", default_log_dir)
|
||||||
|
|
||||||
|
try:
|
||||||
|
comfy_log_max_mb = int(os.getenv("COMFY_LOG_MAX_MB", "10"))
|
||||||
|
except ValueError:
|
||||||
|
comfy_log_max_mb = 10
|
||||||
|
|
||||||
|
comfy_autostart = os.getenv("COMFY_AUTOSTART", "true").lower() not in ("false", "0", "no")
|
||||||
|
|
||||||
|
# Web UI config
|
||||||
|
web_enabled = os.getenv("WEB_ENABLED", "true").lower() not in ("false", "0", "no")
|
||||||
|
web_host = os.getenv("WEB_HOST", "0.0.0.0")
|
||||||
|
try:
|
||||||
|
web_port = int(os.getenv("WEB_PORT", "8080"))
|
||||||
|
except ValueError:
|
||||||
|
web_port = 8080
|
||||||
|
web_secret_key = os.getenv("WEB_SECRET_KEY", "")
|
||||||
|
web_token_file = os.getenv("WEB_TOKEN_FILE", "invite_tokens.json")
|
||||||
|
try:
|
||||||
|
web_jwt_expire_hours = int(os.getenv("WEB_JWT_EXPIRE_HOURS", "8"))
|
||||||
|
except ValueError:
|
||||||
|
web_jwt_expire_hours = 8
|
||||||
|
web_secure_cookie = os.getenv("WEB_SECURE_COOKIE", "true").lower() not in ("false", "0", "no")
|
||||||
|
admin_password = os.getenv("ADMIN_PASSWORD") or None
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
discord_bot_token=discord_token,
|
||||||
|
comfy_server=comfy_server,
|
||||||
|
comfy_output_path=comfy_output_path,
|
||||||
|
comfy_input_path=comfy_input_path,
|
||||||
|
comfy_history_limit=comfy_history_limit,
|
||||||
|
comfy_input_channel_id=comfy_input_channel_id,
|
||||||
|
comfy_service_name=comfy_service_name,
|
||||||
|
comfy_start_bat=comfy_start_bat,
|
||||||
|
comfy_log_dir=comfy_log_dir,
|
||||||
|
comfy_log_max_mb=comfy_log_max_mb,
|
||||||
|
comfy_autostart=comfy_autostart,
|
||||||
|
workflow_file=workflow_file,
|
||||||
|
log_channel_id=log_channel_id,
|
||||||
|
zip_password=zip_password,
|
||||||
|
media_upload_user=media_upload_user,
|
||||||
|
media_upload_pass=media_upload_pass,
|
||||||
|
web_enabled=web_enabled,
|
||||||
|
web_host=web_host,
|
||||||
|
web_port=web_port,
|
||||||
|
web_secret_key=web_secret_key,
|
||||||
|
web_token_file=web_token_file,
|
||||||
|
web_jwt_expire_hours=web_jwt_expire_hours,
|
||||||
|
web_secure_cookie=web_secure_cookie,
|
||||||
|
admin_password=admin_password,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Return a string representation with sensitive data masked."""
|
||||||
|
return (
|
||||||
|
f"BotConfig("
|
||||||
|
f"discord_bot_token='***masked***', "
|
||||||
|
f"comfy_server='{self.comfy_server}', "
|
||||||
|
f"comfy_output_path='{self.comfy_output_path}', "
|
||||||
|
f"comfy_input_path='{self.comfy_input_path}', "
|
||||||
|
f"comfy_history_limit={self.comfy_history_limit}, "
|
||||||
|
f"comfy_input_channel_id={self.comfy_input_channel_id}, "
|
||||||
|
f"workflow_file={self.workflow_file!r}, "
|
||||||
|
f"log_channel_id={self.log_channel_id!r}, "
|
||||||
|
f"zip_password={'***masked***' if self.zip_password else None})"
|
||||||
|
)
|
||||||
251
discord_utils.py
Normal file
251
discord_utils.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
"""
|
||||||
|
discord_utils.py
|
||||||
|
================
|
||||||
|
|
||||||
|
Discord utility functions and helpers for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
This module provides reusable Discord-specific utilities including:
|
||||||
|
- Command decorators for validation
|
||||||
|
- Argument parsing helpers
|
||||||
|
- Message formatting utilities
|
||||||
|
- Discord UI components
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Dict, Optional, List, Tuple
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
from discord.ui import View
|
||||||
|
|
||||||
|
from config import COMFY_NOT_CONFIGURED_MSG
|
||||||
|
|
||||||
|
|
||||||
|
def require_comfy_client(func):
|
||||||
|
"""
|
||||||
|
Decorator that validates bot.comfy exists before executing a command.
|
||||||
|
|
||||||
|
This decorator checks if the bot has a configured ComfyClient instance
|
||||||
|
(bot.comfy) and sends an error message if not. This eliminates the need
|
||||||
|
for repeated validation code in every command.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
@bot.command(name="generate")
|
||||||
|
@require_comfy_client
|
||||||
|
async def generate_command(ctx: commands.Context, *, args: str = ""):
|
||||||
|
# bot.comfy is guaranteed to exist here
|
||||||
|
await bot.comfy.generate_image(...)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
func : callable
|
||||||
|
The command function to wrap.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
callable
|
||||||
|
The wrapped command function with ComfyClient validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(ctx: commands.Context, *args, **kwargs):
|
||||||
|
bot = ctx.bot
|
||||||
|
if not hasattr(bot, "comfy") or bot.comfy is None:
|
||||||
|
await ctx.reply(COMFY_NOT_CONFIGURED_MSG, mention_author=False)
|
||||||
|
return
|
||||||
|
return await func(ctx, *args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def parse_labeled_args(args: str, keys: List[str]) -> Dict[str, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Parse labeled arguments from a command string.
|
||||||
|
|
||||||
|
This parser handles Discord command arguments in the format:
|
||||||
|
"key1:value1 key2:value2 ..."
|
||||||
|
|
||||||
|
The parser splits on keyword markers and preserves case in values.
|
||||||
|
If a key is not found in the args string, its value will be None.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
args : str
|
||||||
|
The argument string to parse.
|
||||||
|
keys : List[str]
|
||||||
|
List of keys to extract (e.g., ["prompt:", "negative_prompt:"]).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Dict[str, Optional[str]]
|
||||||
|
Dictionary mapping keys (without colons) to their values.
|
||||||
|
Keys not found in args will have None values.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> parse_labeled_args("prompt:a cat negative_prompt:blurry", ["prompt:", "negative_prompt:"])
|
||||||
|
{"prompt": "a cat", "negative_prompt": "blurry"}
|
||||||
|
|
||||||
|
>>> parse_labeled_args("prompt:hello world", ["prompt:", "type:"])
|
||||||
|
{"prompt": "hello world", "type": None}
|
||||||
|
"""
|
||||||
|
result = {key.rstrip(":"): None for key in keys}
|
||||||
|
remaining = args
|
||||||
|
|
||||||
|
# Sort keys by position in string to parse left-to-right
|
||||||
|
found_keys = []
|
||||||
|
for key in keys:
|
||||||
|
if key in remaining:
|
||||||
|
idx = remaining.find(key)
|
||||||
|
found_keys.append((idx, key))
|
||||||
|
|
||||||
|
found_keys.sort()
|
||||||
|
|
||||||
|
for i, (_, key) in enumerate(found_keys):
|
||||||
|
# Split on this key
|
||||||
|
parts = remaining.split(key, 1)
|
||||||
|
if len(parts) < 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
value_part = parts[1]
|
||||||
|
|
||||||
|
# Find the next key, if any
|
||||||
|
next_key_idx = len(value_part)
|
||||||
|
if i + 1 < len(found_keys):
|
||||||
|
next_key = found_keys[i + 1][1]
|
||||||
|
if next_key in value_part:
|
||||||
|
next_key_idx = value_part.find(next_key)
|
||||||
|
|
||||||
|
# Extract value up to next key
|
||||||
|
value = value_part[:next_key_idx].strip()
|
||||||
|
result[key.rstrip(":")] = value if value else None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def convert_image_bytes_to_discord_files(
|
||||||
|
images: List[bytes], max_files: int = 4, prefix: str = "generated"
|
||||||
|
) -> List[discord.File]:
|
||||||
|
"""
|
||||||
|
Convert a list of image bytes to Discord File objects.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
images : List[bytes]
|
||||||
|
List of raw image data as bytes.
|
||||||
|
max_files : int
|
||||||
|
Maximum number of files to convert (default: 4, Discord's limit).
|
||||||
|
prefix : str
|
||||||
|
Filename prefix for generated files (default: "generated").
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[discord.File]
|
||||||
|
List of Discord.File objects ready to send.
|
||||||
|
"""
|
||||||
|
files = []
|
||||||
|
for idx, img_bytes in enumerate(images):
|
||||||
|
if idx >= max_files:
|
||||||
|
break
|
||||||
|
file_obj = BytesIO(img_bytes)
|
||||||
|
file_obj.seek(0)
|
||||||
|
files.append(discord.File(file_obj, filename=f"{prefix}_{idx + 1}.png"))
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
async def send_queue_status(ctx: commands.Context, queue_size: int) -> None:
|
||||||
|
"""
|
||||||
|
Send a queue status message to the channel.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ctx : commands.Context
|
||||||
|
The command context.
|
||||||
|
queue_size : int
|
||||||
|
Current number of jobs in the queue.
|
||||||
|
"""
|
||||||
|
await ctx.send(f"Queue size: {queue_size}", mention_author=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def send_typing_with_callback(ctx: commands.Context, callback):
|
||||||
|
"""
|
||||||
|
Execute a callback while showing typing indicator.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ctx : commands.Context
|
||||||
|
The command context.
|
||||||
|
callback : callable
|
||||||
|
Async function to execute while typing.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Any
|
||||||
|
The return value of the callback.
|
||||||
|
"""
|
||||||
|
async with ctx.typing():
|
||||||
|
return await callback()
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_text(text: str, length: int = 50) -> str:
|
||||||
|
"""
|
||||||
|
Truncate text to a maximum length with ellipsis.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
text : str
|
||||||
|
The text to truncate.
|
||||||
|
length : int
|
||||||
|
Maximum length (default: 50).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
Truncated text with "..." suffix if longer than length.
|
||||||
|
"""
|
||||||
|
return text if len(text) <= length else text[: length - 3] + "..."
|
||||||
|
|
||||||
|
|
||||||
|
def extract_arg_value(args: str, key: str) -> Tuple[Optional[str], str]:
|
||||||
|
"""
|
||||||
|
Extract a single argument value from a labeled args string.
|
||||||
|
|
||||||
|
This is a simpler alternative to parse_labeled_args for extracting just
|
||||||
|
one value.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
args : str
|
||||||
|
The full argument string.
|
||||||
|
key : str
|
||||||
|
The key to extract (e.g., "type:").
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tuple[Optional[str], str]
|
||||||
|
A tuple of (extracted_value, remaining_args). If key not found,
|
||||||
|
returns (None, original_args).
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> extract_arg_value("type:input some other text", "type:")
|
||||||
|
("input", "some other text")
|
||||||
|
"""
|
||||||
|
if key not in args:
|
||||||
|
return None, args
|
||||||
|
|
||||||
|
parts = args.split(key, 1)
|
||||||
|
if len(parts) < 2:
|
||||||
|
return None, args
|
||||||
|
|
||||||
|
value_and_rest = parts[1].strip()
|
||||||
|
# Take first word as value
|
||||||
|
words = value_and_rest.split(None, 1)
|
||||||
|
value = words[0] if words else None
|
||||||
|
remaining = words[1] if len(words) > 1 else ""
|
||||||
|
|
||||||
|
return value, remaining
|
||||||
12
frontend/index.html
Normal file
12
frontend/index.html
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<title>ComfyUI Bot</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="root"></div>
|
||||||
|
<script type="module" src="/src/main.tsx"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
2774
frontend/package-lock.json
generated
Normal file
2774
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
27
frontend/package.json
Normal file
27
frontend/package.json
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
{
|
||||||
|
"name": "comfyui-bot-ui",
|
||||||
|
"version": "0.1.0",
|
||||||
|
"private": true,
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "tsc && vite build",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"react": "^18.3.1",
|
||||||
|
"react-dom": "^18.3.1",
|
||||||
|
"react-router-dom": "^6.27.0",
|
||||||
|
"@tanstack/react-query": "^5.62.0"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/react": "^18.3.12",
|
||||||
|
"@types/react-dom": "^18.3.1",
|
||||||
|
"@vitejs/plugin-react": "^4.3.3",
|
||||||
|
"autoprefixer": "^10.4.20",
|
||||||
|
"postcss": "^8.4.49",
|
||||||
|
"tailwindcss": "^3.4.15",
|
||||||
|
"typescript": "^5.6.3",
|
||||||
|
"vite": "^5.4.11"
|
||||||
|
}
|
||||||
|
}
|
||||||
6
frontend/postcss.config.js
Normal file
6
frontend/postcss.config.js
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
export default {
|
||||||
|
plugins: {
|
||||||
|
tailwindcss: {},
|
||||||
|
autoprefixer: {},
|
||||||
|
},
|
||||||
|
}
|
||||||
68
frontend/src/App.tsx
Normal file
68
frontend/src/App.tsx
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import { BrowserRouter, Routes, Route, Navigate } from 'react-router-dom'
|
||||||
|
import { useAuth } from './hooks/useAuth'
|
||||||
|
import Layout from './components/Layout'
|
||||||
|
import { GenerationProvider } from './context/GenerationContext'
|
||||||
|
import LoginPage from './pages/LoginPage'
|
||||||
|
import GeneratePage from './pages/GeneratePage'
|
||||||
|
import InputImagesPage from './pages/InputImagesPage'
|
||||||
|
import WorkflowPage from './pages/WorkflowPage'
|
||||||
|
import PresetsPage from './pages/PresetsPage'
|
||||||
|
import StatusPage from './pages/StatusPage'
|
||||||
|
import ServerPage from './pages/ServerPage'
|
||||||
|
import HistoryPage from './pages/HistoryPage'
|
||||||
|
import AdminPage from './pages/AdminPage'
|
||||||
|
import SharePage from './pages/SharePage'
|
||||||
|
|
||||||
|
function RequireAuth({ children }: { children: React.ReactNode }) {
|
||||||
|
const { isAuthenticated, isLoading } = useAuth()
|
||||||
|
if (isLoading) return <div className="flex items-center justify-center h-screen text-gray-500">Loading...</div>
|
||||||
|
if (!isAuthenticated) return <Navigate to="/login" replace />
|
||||||
|
return <>{children}</>
|
||||||
|
}
|
||||||
|
|
||||||
|
function RequireAdmin({ children }: { children: React.ReactNode }) {
|
||||||
|
const { isAdmin, isLoading } = useAuth()
|
||||||
|
if (isLoading) return null
|
||||||
|
if (!isAdmin) return <Navigate to="/generate" replace />
|
||||||
|
return <>{children}</>
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function App() {
|
||||||
|
return (
|
||||||
|
<BrowserRouter>
|
||||||
|
<Routes>
|
||||||
|
<Route path="/login" element={<LoginPage />} />
|
||||||
|
<Route
|
||||||
|
path="/"
|
||||||
|
element={
|
||||||
|
<RequireAuth>
|
||||||
|
<GenerationProvider>
|
||||||
|
<Layout />
|
||||||
|
</GenerationProvider>
|
||||||
|
</RequireAuth>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<Route index element={<Navigate to="/generate" replace />} />
|
||||||
|
<Route path="generate" element={<GeneratePage />} />
|
||||||
|
<Route path="inputs" element={<InputImagesPage />} />
|
||||||
|
<Route path="workflow" element={<WorkflowPage />} />
|
||||||
|
<Route path="presets" element={<PresetsPage />} />
|
||||||
|
<Route path="status" element={<StatusPage />} />
|
||||||
|
<Route path="server" element={<ServerPage />} />
|
||||||
|
<Route path="history" element={<HistoryPage />} />
|
||||||
|
<Route
|
||||||
|
path="admin"
|
||||||
|
element={
|
||||||
|
<RequireAdmin>
|
||||||
|
<AdminPage />
|
||||||
|
</RequireAdmin>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</Route>
|
||||||
|
<Route path="/share/:token" element={<SharePage />} />
|
||||||
|
<Route path="*" element={<Navigate to="/generate" replace />} />
|
||||||
|
</Routes>
|
||||||
|
</BrowserRouter>
|
||||||
|
)
|
||||||
|
}
|
||||||
198
frontend/src/api/client.ts
Normal file
198
frontend/src/api/client.ts
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
/** Typed API client for the ComfyUI Bot web API. */
|
||||||
|
|
||||||
|
const BASE = '' // same-origin in prod; Vite proxy in dev
|
||||||
|
|
||||||
|
async function _fetch<T>(path: string, init?: RequestInit): Promise<T> {
|
||||||
|
const res = await fetch(BASE + path, {
|
||||||
|
credentials: 'include',
|
||||||
|
headers: { 'Content-Type': 'application/json', ...init?.headers },
|
||||||
|
...init,
|
||||||
|
})
|
||||||
|
if (!res.ok) {
|
||||||
|
const msg = await res.text().catch(() => res.statusText)
|
||||||
|
throw new Error(`${res.status}: ${msg}`)
|
||||||
|
}
|
||||||
|
return res.json() as Promise<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth
|
||||||
|
export const authLogin = (token: string) =>
|
||||||
|
_fetch<{ label: string; admin: boolean }>('/api/auth/login', {
|
||||||
|
method: 'POST',
|
||||||
|
body: JSON.stringify({ token }),
|
||||||
|
})
|
||||||
|
|
||||||
|
export const authLogout = () =>
|
||||||
|
_fetch<{ ok: boolean }>('/api/auth/logout', { method: 'POST' })
|
||||||
|
|
||||||
|
export const authMe = () =>
|
||||||
|
_fetch<{ label: string; admin: boolean }>('/api/auth/me')
|
||||||
|
|
||||||
|
// Admin
|
||||||
|
export const adminLogin = (password: string) =>
|
||||||
|
_fetch<{ label: string; admin: boolean }>('/api/admin/login', {
|
||||||
|
method: 'POST',
|
||||||
|
body: JSON.stringify({ password }),
|
||||||
|
})
|
||||||
|
|
||||||
|
export const adminListTokens = () =>
|
||||||
|
_fetch<Array<{ id: string; label: string; admin: boolean; created_at: string }>>('/api/admin/tokens')
|
||||||
|
|
||||||
|
export const adminCreateToken = (label: string, admin = false) =>
|
||||||
|
_fetch<{ token: string; label: string; admin: boolean }>('/api/admin/tokens', {
|
||||||
|
method: 'POST',
|
||||||
|
body: JSON.stringify({ label, admin }),
|
||||||
|
})
|
||||||
|
|
||||||
|
export const adminRevokeToken = (id: string) =>
|
||||||
|
_fetch<{ ok: boolean }>(`/api/admin/tokens/${id}`, { method: 'DELETE' })
|
||||||
|
|
||||||
|
// Status
|
||||||
|
export const getStatus = () => _fetch<Record<string, unknown>>('/api/status')
|
||||||
|
|
||||||
|
// State / overrides
|
||||||
|
export const getState = () => _fetch<Record<string, unknown>>('/api/state')
|
||||||
|
|
||||||
|
export const putState = (overrides: Record<string, unknown>) =>
|
||||||
|
_fetch<Record<string, unknown>>('/api/state', {
|
||||||
|
method: 'PUT',
|
||||||
|
body: JSON.stringify(overrides),
|
||||||
|
})
|
||||||
|
|
||||||
|
export const deleteStateKey = (key: string) =>
|
||||||
|
_fetch<{ ok: boolean }>(`/api/state/${key}`, { method: 'DELETE' })
|
||||||
|
|
||||||
|
// Generation
|
||||||
|
export interface GenerateRequest {
|
||||||
|
prompt: string
|
||||||
|
negative_prompt?: string
|
||||||
|
overrides?: Record<string, unknown>
|
||||||
|
}
|
||||||
|
export const generate = (body: GenerateRequest) =>
|
||||||
|
_fetch<{ queued: boolean; queue_position: number }>('/api/generate', {
|
||||||
|
method: 'POST',
|
||||||
|
body: JSON.stringify(body),
|
||||||
|
})
|
||||||
|
|
||||||
|
export interface WorkflowGenRequest {
|
||||||
|
count?: number
|
||||||
|
overrides?: Record<string, unknown>
|
||||||
|
}
|
||||||
|
export const workflowGen = (body: WorkflowGenRequest) =>
|
||||||
|
_fetch<{ queued: boolean; count: number; queue_position: number }>('/api/workflow-gen', {
|
||||||
|
method: 'POST',
|
||||||
|
body: JSON.stringify(body),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Inputs
|
||||||
|
export interface InputImage {
|
||||||
|
id: number
|
||||||
|
original_message_id: number
|
||||||
|
bot_reply_id: number | null
|
||||||
|
channel_id: number
|
||||||
|
filename: string
|
||||||
|
is_active: number
|
||||||
|
active_slot_key: string | null
|
||||||
|
}
|
||||||
|
export const listInputs = () => _fetch<InputImage[]>('/api/inputs')
|
||||||
|
|
||||||
|
export const uploadInput = (file: File, slotKey = 'input_image') => {
|
||||||
|
const form = new FormData()
|
||||||
|
form.append('file', file)
|
||||||
|
form.append('slot_key', slotKey)
|
||||||
|
return fetch('/api/inputs', {
|
||||||
|
method: 'POST',
|
||||||
|
credentials: 'include',
|
||||||
|
body: form,
|
||||||
|
}).then(r => r.json())
|
||||||
|
}
|
||||||
|
|
||||||
|
export const activateInput = (id: number, slotKey = 'input_image') =>
|
||||||
|
_fetch<{ ok: boolean }>(`/api/inputs/${id}/activate?slot_key=${slotKey}`, { method: 'POST' })
|
||||||
|
|
||||||
|
export const deleteInput = (id: number) =>
|
||||||
|
_fetch<{ ok: boolean }>(`/api/inputs/${id}`, { method: 'DELETE' })
|
||||||
|
|
||||||
|
export const getInputImage = (id: number) => `/api/inputs/${id}/image`
|
||||||
|
export const getInputThumb = (id: number) => `/api/inputs/${id}/thumb`
|
||||||
|
export const getInputMid = (id: number) => `/api/inputs/${id}/mid`
|
||||||
|
|
||||||
|
// Presets
|
||||||
|
export interface PresetMeta { name: string; owner: string | null; description: string | null }
|
||||||
|
export const listPresets = () => _fetch<{ presets: PresetMeta[] }>('/api/presets')
|
||||||
|
export const savePreset = (name: string, description?: string) =>
|
||||||
|
_fetch<{ ok: boolean }>('/api/presets', { method: 'POST', body: JSON.stringify({ name, description: description ?? null }) })
|
||||||
|
export const getPreset = (name: string) => _fetch<Record<string, unknown>>(`/api/presets/${name}`)
|
||||||
|
export const loadPreset = (name: string) =>
|
||||||
|
_fetch<{ ok: boolean }>(`/api/presets/${name}/load`, { method: 'POST' })
|
||||||
|
export const deletePreset = (name: string) =>
|
||||||
|
_fetch<{ ok: boolean }>(`/api/presets/${name}`, { method: 'DELETE' })
|
||||||
|
export const savePresetFromHistory = (promptId: string, name: string, description?: string) =>
|
||||||
|
_fetch<{ ok: boolean; name: string }>(`/api/presets/from-history/${promptId}`, {
|
||||||
|
method: 'POST',
|
||||||
|
body: JSON.stringify({ name, description: description ?? null }),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Server
|
||||||
|
export const getServerStatus = () =>
|
||||||
|
_fetch<{ service_state: string; http_reachable: boolean }>('/api/server/status')
|
||||||
|
export const serverAction = (action: string) =>
|
||||||
|
_fetch<{ ok: boolean }>(`/api/server/${action}`, { method: 'POST' })
|
||||||
|
export const tailLogs = (lines = 100) =>
|
||||||
|
_fetch<{ lines: string[] }>(`/api/logs/tail?lines=${lines}`)
|
||||||
|
|
||||||
|
// History
|
||||||
|
export const getHistory = (q?: string) =>
|
||||||
|
_fetch<{ history: Array<Record<string, unknown>> }>(q ? `/api/history?q=${encodeURIComponent(q)}` : '/api/history')
|
||||||
|
|
||||||
|
export const createHistoryShare = (promptId: string) =>
|
||||||
|
_fetch<{ share_token: string }>(`/api/history/${promptId}/share`, { method: 'POST' })
|
||||||
|
|
||||||
|
export const revokeHistoryShare = (promptId: string) =>
|
||||||
|
_fetch<{ ok: boolean }>(`/api/history/${promptId}/share`, { method: 'DELETE' })
|
||||||
|
|
||||||
|
export const getShareFileUrl = (token: string, filename: string) =>
|
||||||
|
`/api/share/${token}/file/${encodeURIComponent(filename)}`
|
||||||
|
|
||||||
|
// Workflow
|
||||||
|
export const getWorkflow = () =>
|
||||||
|
_fetch<{ loaded: boolean; node_count: number; last_workflow_file: string | null }>('/api/workflow')
|
||||||
|
|
||||||
|
export interface NodeInput {
|
||||||
|
key: string
|
||||||
|
label: string
|
||||||
|
input_type: string
|
||||||
|
current_value: unknown
|
||||||
|
node_class: string
|
||||||
|
node_title: string
|
||||||
|
is_common: boolean
|
||||||
|
}
|
||||||
|
export const getWorkflowInputs = () =>
|
||||||
|
_fetch<{ common: NodeInput[]; advanced: NodeInput[] }>('/api/workflow/inputs')
|
||||||
|
|
||||||
|
export const listWorkflowFiles = () =>
|
||||||
|
_fetch<{ files: string[] }>('/api/workflow/files')
|
||||||
|
|
||||||
|
export const uploadWorkflow = (file: File) => {
|
||||||
|
const form = new FormData()
|
||||||
|
form.append('file', file)
|
||||||
|
return fetch('/api/workflow/upload', {
|
||||||
|
method: 'POST',
|
||||||
|
credentials: 'include',
|
||||||
|
body: form,
|
||||||
|
}).then(r => r.json())
|
||||||
|
}
|
||||||
|
|
||||||
|
export const loadWorkflow = (filename: string) => {
|
||||||
|
const form = new FormData()
|
||||||
|
form.append('filename', filename)
|
||||||
|
return fetch('/api/workflow/load', {
|
||||||
|
method: 'POST',
|
||||||
|
credentials: 'include',
|
||||||
|
body: form,
|
||||||
|
}).then(r => r.json())
|
||||||
|
}
|
||||||
|
|
||||||
|
export const getModels = (type: 'checkpoints' | 'loras') =>
|
||||||
|
_fetch<{ type: string; models: string[] }>(`/api/workflow/models?type=${type}`)
|
||||||
|
|
||||||
329
frontend/src/components/DynamicWorkflowForm.tsx
Normal file
329
frontend/src/components/DynamicWorkflowForm.tsx
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
import React, { useEffect, useState } from 'react'
|
||||||
|
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||||
|
import {
|
||||||
|
getWorkflowInputs,
|
||||||
|
getModels,
|
||||||
|
putState,
|
||||||
|
deleteStateKey,
|
||||||
|
getState,
|
||||||
|
NodeInput,
|
||||||
|
activateInput,
|
||||||
|
listInputs,
|
||||||
|
getInputImage,
|
||||||
|
getInputThumb,
|
||||||
|
getInputMid,
|
||||||
|
} from '../api/client'
|
||||||
|
import LazyImage from './LazyImage'
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
/** Called when the Generate button is clicked with the current overrides */
|
||||||
|
onGenerate: (overrides: Record<string, unknown>, count: number) => void
|
||||||
|
/** Live seed from WS generation_complete event */
|
||||||
|
lastSeed?: number | null
|
||||||
|
generating?: boolean
|
||||||
|
/** The authenticated user's label (used to find their active input image) */
|
||||||
|
userLabel?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function DynamicWorkflowForm({ onGenerate, lastSeed, generating, userLabel }: Props) {
|
||||||
|
const qc = useQueryClient()
|
||||||
|
const { data: inputsData, isLoading: inputsLoading } = useQuery({
|
||||||
|
queryKey: ['workflow', 'inputs'],
|
||||||
|
queryFn: getWorkflowInputs,
|
||||||
|
})
|
||||||
|
const { data: stateData } = useQuery({
|
||||||
|
queryKey: ['state'],
|
||||||
|
queryFn: getState,
|
||||||
|
})
|
||||||
|
const { data: checkpoints } = useQuery({
|
||||||
|
queryKey: ['models', 'checkpoints'],
|
||||||
|
queryFn: () => getModels('checkpoints'),
|
||||||
|
staleTime: 60_000,
|
||||||
|
})
|
||||||
|
const { data: loras } = useQuery({
|
||||||
|
queryKey: ['models', 'loras'],
|
||||||
|
queryFn: () => getModels('loras'),
|
||||||
|
staleTime: 60_000,
|
||||||
|
})
|
||||||
|
const { data: inputImages } = useQuery({
|
||||||
|
queryKey: ['inputs'],
|
||||||
|
queryFn: listInputs,
|
||||||
|
})
|
||||||
|
|
||||||
|
const [localValues, setLocalValues] = useState<Record<string, unknown>>({})
|
||||||
|
const [randomSeeds, setRandomSeeds] = useState<Record<string, boolean>>({})
|
||||||
|
const [imagePicker, setImagePicker] = useState<string | null>(null) // key of slot being picked
|
||||||
|
const [count, setCount] = useState(1)
|
||||||
|
|
||||||
|
// Sync local values from state when stateData arrives
|
||||||
|
useEffect(() => {
|
||||||
|
if (stateData) setLocalValues(stateData as Record<string, unknown>)
|
||||||
|
}, [stateData])
|
||||||
|
|
||||||
|
// Update seed field when WS reports completed seed
|
||||||
|
useEffect(() => {
|
||||||
|
if (lastSeed != null) {
|
||||||
|
setLocalValues(v => ({ ...v, seed: lastSeed }))
|
||||||
|
}
|
||||||
|
}, [lastSeed])
|
||||||
|
|
||||||
|
const putStateMut = useMutation({
|
||||||
|
mutationFn: (overrides: Record<string, unknown>) => putState(overrides),
|
||||||
|
onSuccess: () => qc.invalidateQueries({ queryKey: ['state'] }),
|
||||||
|
})
|
||||||
|
const deleteKeyMut = useMutation({
|
||||||
|
mutationFn: (key: string) => deleteStateKey(key),
|
||||||
|
onSuccess: () => qc.invalidateQueries({ queryKey: ['state'] }),
|
||||||
|
})
|
||||||
|
|
||||||
|
const setValue = (key: string, value: unknown) => {
|
||||||
|
setLocalValues(v => ({ ...v, [key]: value }))
|
||||||
|
putStateMut.mutate({ [key]: value })
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleActivateImage = async (imageId: number, slotKey: string) => {
|
||||||
|
await activateInput(imageId, slotKey)
|
||||||
|
qc.invalidateQueries({ queryKey: ['inputs'] })
|
||||||
|
qc.invalidateQueries({ queryKey: ['state'] })
|
||||||
|
setImagePicker(null)
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleGenerate = () => {
|
||||||
|
const overrides: Record<string, unknown> = {}
|
||||||
|
const allInputs = [...(inputsData?.common ?? []), ...(inputsData?.advanced ?? [])]
|
||||||
|
for (const inp of allInputs) {
|
||||||
|
if (inp.input_type === 'seed') {
|
||||||
|
overrides[inp.key] = randomSeeds[inp.key] !== false ? -1 : (localValues[inp.key] ?? -1)
|
||||||
|
} else if (inp.input_type === 'image') {
|
||||||
|
// image slot — server reads from state_manager
|
||||||
|
} else {
|
||||||
|
const v = localValues[inp.key]
|
||||||
|
if (v !== undefined && v !== '') overrides[inp.key] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
onGenerate(overrides, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inputsLoading) return <div className="text-sm text-gray-400">Loading workflow inputs…</div>
|
||||||
|
if (!inputsData) return <div className="text-sm text-gray-400">No workflow loaded.</div>
|
||||||
|
|
||||||
|
const renderField = (inp: NodeInput) => {
|
||||||
|
const val = localValues[inp.key] ?? inp.current_value
|
||||||
|
|
||||||
|
if (inp.input_type === 'text') {
|
||||||
|
return (
|
||||||
|
<textarea
|
||||||
|
rows={3}
|
||||||
|
className="w-full border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 resize-y focus:outline-none focus:ring-1 focus:ring-blue-500"
|
||||||
|
value={String(val ?? '')}
|
||||||
|
onChange={e => setValue(inp.key, e.target.value)}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inp.input_type === 'seed') {
|
||||||
|
const isRandom = randomSeeds[inp.key] !== false
|
||||||
|
return (
|
||||||
|
<div className="flex gap-2 items-center">
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
className="flex-1 border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500 disabled:opacity-40"
|
||||||
|
value={isRandom ? '' : String(val ?? '')}
|
||||||
|
placeholder={isRandom ? 'Random' : undefined}
|
||||||
|
disabled={isRandom}
|
||||||
|
onChange={e => setValue(inp.key, Number(e.target.value))}
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => setRandomSeeds(r => ({ ...r, [inp.key]: !isRandom }))}
|
||||||
|
className={`text-xs px-2 py-1 rounded border ${isRandom ? 'bg-blue-600 text-white border-blue-600' : 'border-gray-400 text-gray-600 dark:text-gray-300'}`}
|
||||||
|
>
|
||||||
|
🎲 Random
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inp.input_type === 'image') {
|
||||||
|
const activeFilename = String(val ?? '')
|
||||||
|
const namespacedKey = userLabel ? `${userLabel}_${inp.key}` : inp.key
|
||||||
|
const activeImg = inputImages?.find(i => i.active_slot_key === namespacedKey)
|
||||||
|
return (
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
{activeImg ? (
|
||||||
|
<img
|
||||||
|
src={getInputThumb(activeImg.id)}
|
||||||
|
alt={activeImg.filename}
|
||||||
|
className="w-16 h-16 object-cover rounded border border-gray-300 dark:border-gray-600"
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<div className="w-16 h-16 rounded border border-dashed border-gray-400 flex items-center justify-center text-xs text-gray-400">none</div>
|
||||||
|
)}
|
||||||
|
<div className="flex flex-col gap-1">
|
||||||
|
<span className="text-xs text-gray-500 dark:text-gray-400 truncate max-w-[12rem]">{activeFilename || 'No image active'}</span>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => setImagePicker(imagePicker === inp.key ? null : inp.key)}
|
||||||
|
className="text-xs bg-gray-200 dark:bg-gray-600 hover:bg-gray-300 dark:hover:bg-gray-500 rounded px-2 py-0.5"
|
||||||
|
>
|
||||||
|
Browse
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inp.input_type === 'checkpoint') {
|
||||||
|
return (
|
||||||
|
<select
|
||||||
|
className="w-full border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
|
||||||
|
value={String(val ?? '')}
|
||||||
|
onChange={e => setValue(inp.key, e.target.value)}
|
||||||
|
>
|
||||||
|
<option value="">— select checkpoint —</option>
|
||||||
|
{(checkpoints?.models ?? []).map(m => <option key={m} value={m}>{m}</option>)}
|
||||||
|
</select>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inp.input_type === 'lora') {
|
||||||
|
return (
|
||||||
|
<select
|
||||||
|
className="w-full border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
|
||||||
|
value={String(val ?? '')}
|
||||||
|
onChange={e => setValue(inp.key, e.target.value)}
|
||||||
|
>
|
||||||
|
<option value="">— select lora —</option>
|
||||||
|
{(loras?.models ?? []).map(m => <option key={m} value={m}>{m}</option>)}
|
||||||
|
</select>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// integer, float, string
|
||||||
|
return (
|
||||||
|
<input
|
||||||
|
type={inp.input_type === 'integer' || inp.input_type === 'float' ? 'number' : 'text'}
|
||||||
|
step={inp.input_type === 'float' ? 'any' : undefined}
|
||||||
|
className="w-full border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
|
||||||
|
value={String(val ?? '')}
|
||||||
|
placeholder={String(inp.current_value ?? '')}
|
||||||
|
onChange={e => {
|
||||||
|
const raw = e.target.value
|
||||||
|
const coerced = inp.input_type === 'integer' ? parseInt(raw) || raw
|
||||||
|
: inp.input_type === 'float' ? parseFloat(raw) || raw
|
||||||
|
: raw
|
||||||
|
setValue(inp.key, coerced)
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-4">
|
||||||
|
{/* Common inputs */}
|
||||||
|
{inputsData.common.map(inp => (
|
||||||
|
<div key={inp.key}>
|
||||||
|
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300 mb-1">{inp.label}</label>
|
||||||
|
{renderField(inp)}
|
||||||
|
{imagePicker === inp.key && (
|
||||||
|
<ImagePickerGrid
|
||||||
|
images={inputImages ?? []}
|
||||||
|
slotKey={inp.key}
|
||||||
|
onPick={handleActivateImage}
|
||||||
|
onClose={() => setImagePicker(null)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
|
||||||
|
{/* Advanced inputs */}
|
||||||
|
{inputsData.advanced.length > 0 && (
|
||||||
|
<details className="border border-gray-200 dark:border-gray-700 rounded">
|
||||||
|
<summary className="px-3 py-2 text-sm font-medium cursor-pointer select-none text-gray-700 dark:text-gray-300">
|
||||||
|
Advanced ({inputsData.advanced.length} inputs)
|
||||||
|
</summary>
|
||||||
|
<div className="px-3 pb-3 space-y-3 mt-2">
|
||||||
|
{inputsData.advanced.map(inp => (
|
||||||
|
<div key={inp.key}>
|
||||||
|
<label className="block text-xs font-medium text-gray-600 dark:text-gray-400 mb-1">{inp.label}</label>
|
||||||
|
{renderField(inp)}
|
||||||
|
{imagePicker === inp.key && (
|
||||||
|
<ImagePickerGrid
|
||||||
|
images={inputImages ?? []}
|
||||||
|
slotKey={inp.key}
|
||||||
|
onPick={handleActivateImage}
|
||||||
|
onClose={() => setImagePicker(null)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="flex gap-2 items-center">
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
min={1}
|
||||||
|
max={20}
|
||||||
|
value={count}
|
||||||
|
onChange={e => setCount(Math.max(1, Math.min(20, Number(e.target.value))))}
|
||||||
|
className="w-16 border border-gray-300 dark:border-gray-600 rounded px-2 py-2 text-sm text-center bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
|
||||||
|
title="Number of generations to queue"
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={handleGenerate}
|
||||||
|
disabled={generating}
|
||||||
|
className="flex-1 bg-blue-600 hover:bg-blue-700 disabled:opacity-50 text-white rounded px-4 py-2 text-sm font-semibold transition-colors"
|
||||||
|
>
|
||||||
|
{generating ? '⏳ Generating…' : count > 1 ? `Generate ×${count}` : 'Generate'}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function ImagePickerGrid({
|
||||||
|
images,
|
||||||
|
slotKey,
|
||||||
|
onPick,
|
||||||
|
onClose,
|
||||||
|
}: {
|
||||||
|
images: { id: number; filename: string }[]
|
||||||
|
slotKey: string
|
||||||
|
onPick: (id: number, key: string) => void
|
||||||
|
onClose: () => void
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div className="mt-2 border border-gray-300 dark:border-gray-600 rounded p-2 bg-gray-50 dark:bg-gray-800">
|
||||||
|
<div className="flex justify-between items-center mb-2">
|
||||||
|
<span className="text-xs text-gray-500 dark:text-gray-400">Select image for slot: {slotKey}</span>
|
||||||
|
<button onClick={onClose} className="text-xs text-gray-400 hover:text-gray-600 dark:hover:text-gray-200">✕</button>
|
||||||
|
</div>
|
||||||
|
{images.length === 0 ? (
|
||||||
|
<p className="text-xs text-gray-400">No images available. Upload some first.</p>
|
||||||
|
) : (
|
||||||
|
<div className="grid grid-cols-3 sm:grid-cols-4 gap-1 max-h-48 overflow-y-auto">
|
||||||
|
{images.map(img => (
|
||||||
|
<button
|
||||||
|
key={img.id}
|
||||||
|
type="button"
|
||||||
|
onClick={() => onPick(img.id, slotKey)}
|
||||||
|
className="relative aspect-square overflow-hidden rounded border border-gray-200 dark:border-gray-600 hover:border-blue-500"
|
||||||
|
title={img.filename}
|
||||||
|
>
|
||||||
|
<LazyImage
|
||||||
|
thumbSrc={getInputThumb(img.id)}
|
||||||
|
midSrc={getInputMid(img.id)}
|
||||||
|
fullSrc={getInputImage(img.id)}
|
||||||
|
alt={img.filename}
|
||||||
|
className="w-full h-full"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
146
frontend/src/components/Layout.tsx
Normal file
146
frontend/src/components/Layout.tsx
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
import React, { useEffect, useRef, useState, useCallback } from 'react'
|
||||||
|
import { NavLink, Outlet, useLocation } from 'react-router-dom'
|
||||||
|
import { useQueryClient } from '@tanstack/react-query'
|
||||||
|
import { useAuth } from '../hooks/useAuth'
|
||||||
|
import { useStatus } from '../hooks/useStatus'
|
||||||
|
import { useGeneration } from '../context/GenerationContext'
|
||||||
|
|
||||||
|
const navItems = [
|
||||||
|
{ to: '/generate', label: 'Generate' },
|
||||||
|
{ to: '/inputs', label: 'Input Images' },
|
||||||
|
{ to: '/workflow', label: 'Workflow' },
|
||||||
|
{ to: '/presets', label: 'Presets' },
|
||||||
|
{ to: '/status', label: 'Status' },
|
||||||
|
{ to: '/server', label: 'Server' },
|
||||||
|
{ to: '/history', label: 'History' },
|
||||||
|
]
|
||||||
|
|
||||||
|
export default function Layout() {
|
||||||
|
const { user, logout, isAdmin } = useAuth()
|
||||||
|
const location = useLocation()
|
||||||
|
const [sidebarOpen, setSidebarOpen] = useState(false)
|
||||||
|
const [dark, setDark] = useState(() => {
|
||||||
|
if (typeof window === 'undefined') return false
|
||||||
|
const stored = localStorage.getItem('dark-mode')
|
||||||
|
if (stored !== null) return stored === 'true'
|
||||||
|
return window.matchMedia('(prefers-color-scheme: dark)').matches
|
||||||
|
})
|
||||||
|
|
||||||
|
// Apply dark class on mount and changes
|
||||||
|
useEffect(() => {
|
||||||
|
document.documentElement.classList.toggle('dark', dark)
|
||||||
|
localStorage.setItem('dark-mode', String(dark))
|
||||||
|
}, [dark])
|
||||||
|
|
||||||
|
// Auto-close sidebar on navigation
|
||||||
|
useEffect(() => {
|
||||||
|
setSidebarOpen(false)
|
||||||
|
}, [location.pathname])
|
||||||
|
|
||||||
|
const { pendingCount, decrementPending } = useGeneration()
|
||||||
|
const queryClient = useQueryClient()
|
||||||
|
const titleResetTimer = useRef<ReturnType<typeof setTimeout> | null>(null)
|
||||||
|
|
||||||
|
const onNodeExecuting = useCallback(() => {
|
||||||
|
document.title = '⏳ Generating… | ComfyUI Bot'
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const onGenerationComplete = useCallback(() => {
|
||||||
|
decrementPending()
|
||||||
|
queryClient.invalidateQueries({ queryKey: ['history'] })
|
||||||
|
if (titleResetTimer.current) clearTimeout(titleResetTimer.current)
|
||||||
|
document.title = 'Done | ComfyUI Bot'
|
||||||
|
titleResetTimer.current = setTimeout(() => { document.title = 'ComfyUI Bot' }, 5000)
|
||||||
|
}, [decrementPending, queryClient])
|
||||||
|
|
||||||
|
const onGenerationError = useCallback(() => {
|
||||||
|
decrementPending()
|
||||||
|
if (titleResetTimer.current) clearTimeout(titleResetTimer.current)
|
||||||
|
document.title = 'Error | ComfyUI Bot'
|
||||||
|
titleResetTimer.current = setTimeout(() => { document.title = 'ComfyUI Bot' }, 5000)
|
||||||
|
}, [decrementPending])
|
||||||
|
|
||||||
|
useEffect(() => () => { document.title = 'ComfyUI Bot' }, [])
|
||||||
|
|
||||||
|
const { status } = useStatus({ enabled: !!user, onGenerationComplete, onGenerationError, onNodeExecuting })
|
||||||
|
const comfyReachable = status.comfy?.reachable ?? null
|
||||||
|
|
||||||
|
const toggleDark = () => setDark(d => !d)
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex h-screen overflow-hidden">
|
||||||
|
{/* Mobile backdrop */}
|
||||||
|
{sidebarOpen && (
|
||||||
|
<div
|
||||||
|
className="fixed inset-0 bg-black/40 z-30 md:hidden"
|
||||||
|
onClick={() => setSidebarOpen(false)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Sidebar */}
|
||||||
|
<aside
|
||||||
|
className={`fixed md:static inset-y-0 left-0 z-40 w-48 flex-none bg-gray-800 text-gray-100 flex flex-col transition-transform duration-200 ${
|
||||||
|
sidebarOpen ? 'translate-x-0' : '-translate-x-full md:translate-x-0'
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
<div className="p-4 font-bold text-lg border-b border-gray-700 flex items-center gap-2">
|
||||||
|
<span>ComfyUI Bot</span>
|
||||||
|
<span
|
||||||
|
title={comfyReachable == null ? 'Connecting…' : comfyReachable ? 'ComfyUI reachable' : 'ComfyUI unreachable'}
|
||||||
|
className={`ml-auto w-2 h-2 rounded-full flex-none ${comfyReachable == null ? 'bg-gray-500' : comfyReachable ? 'bg-green-400' : 'bg-red-400'}`}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<nav className="flex-1 overflow-y-auto py-2">
|
||||||
|
{navItems.map(({ to, label }) => (
|
||||||
|
<NavLink
|
||||||
|
key={to}
|
||||||
|
to={to}
|
||||||
|
className={({ isActive }) =>
|
||||||
|
`flex items-center px-4 py-2 text-sm hover:bg-gray-700 transition-colors ${isActive ? 'bg-gray-700 font-medium' : ''}`
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<span className="flex-1">{label}</span>
|
||||||
|
{to === '/generate' && pendingCount > 0 && (
|
||||||
|
<span className="ml-2 text-xs bg-yellow-400 text-gray-900 rounded-full px-1.5 py-0.5 leading-none">
|
||||||
|
{pendingCount}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</NavLink>
|
||||||
|
))}
|
||||||
|
{isAdmin && (
|
||||||
|
<NavLink
|
||||||
|
to="/admin"
|
||||||
|
className={({ isActive }) =>
|
||||||
|
`block px-4 py-2 text-sm hover:bg-gray-700 transition-colors ${isActive ? 'bg-gray-700 font-medium' : ''}`
|
||||||
|
}
|
||||||
|
>
|
||||||
|
Admin
|
||||||
|
</NavLink>
|
||||||
|
)}
|
||||||
|
</nav>
|
||||||
|
<div className="p-3 border-t border-gray-700 text-xs flex items-center gap-2">
|
||||||
|
<span className="flex-1 truncate">{user?.label ?? '...'}</span>
|
||||||
|
<button onClick={toggleDark} className="hover:text-yellow-300" title="Toggle dark mode">
|
||||||
|
{dark ? '☀' : '🌙'}
|
||||||
|
</button>
|
||||||
|
<button onClick={logout} className="hover:text-red-400" title="Logout">
|
||||||
|
⏏
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</aside>
|
||||||
|
|
||||||
|
{/* Main */}
|
||||||
|
<main className="flex-1 overflow-y-auto p-3 sm:p-6 bg-gray-50 dark:bg-gray-900">
|
||||||
|
{/* Hamburger button — mobile only */}
|
||||||
|
<button
|
||||||
|
className="md:hidden mb-3 p-1.5 rounded bg-gray-200 dark:bg-gray-700 text-gray-700 dark:text-gray-300 hover:bg-gray-300 dark:hover:bg-gray-600"
|
||||||
|
onClick={() => setSidebarOpen(true)}
|
||||||
|
aria-label="Open menu"
|
||||||
|
>
|
||||||
|
☰
|
||||||
|
</button>
|
||||||
|
<Outlet />
|
||||||
|
</main>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
59
frontend/src/components/LazyImage.tsx
Normal file
59
frontend/src/components/LazyImage.tsx
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import React, { useState } from 'react'
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
thumbSrc: string
|
||||||
|
midSrc: string
|
||||||
|
fullSrc: string
|
||||||
|
alt: string
|
||||||
|
className?: string
|
||||||
|
onClick?: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 3-stage progressive image loader:
|
||||||
|
* stage 0 — blurred tiny thumb (loads instantly)
|
||||||
|
* stage 1 — clear medium-compressed image (fades in when ready)
|
||||||
|
* stage 2 — full original image (fades in when ready)
|
||||||
|
*
|
||||||
|
* All three requests fire in parallel; stage only advances forward so if full
|
||||||
|
* arrives before mid we jump straight to stage 2.
|
||||||
|
*/
|
||||||
|
export default function LazyImage({ thumbSrc, midSrc, fullSrc, alt, className = '', onClick }: Props) {
|
||||||
|
const [stage, setStage] = useState(0)
|
||||||
|
const advance = (to: number) => setStage(s => Math.max(s, to))
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={`relative overflow-hidden ${className}`}>
|
||||||
|
{/* Stage 0: blurred thumbnail */}
|
||||||
|
<img
|
||||||
|
src={thumbSrc}
|
||||||
|
alt=""
|
||||||
|
aria-hidden="true"
|
||||||
|
className={`absolute inset-0 w-full h-full object-cover transition-opacity duration-300 ${
|
||||||
|
stage === 0 ? 'opacity-100 blur-sm scale-105' : 'opacity-0'
|
||||||
|
}`}
|
||||||
|
/>
|
||||||
|
{/* Stage 1: clear mid-resolution */}
|
||||||
|
<img
|
||||||
|
src={midSrc}
|
||||||
|
alt=""
|
||||||
|
aria-hidden="true"
|
||||||
|
className={`absolute inset-0 w-full h-full object-cover transition-opacity duration-300 ${
|
||||||
|
stage === 1 ? 'opacity-100' : 'opacity-0'
|
||||||
|
}`}
|
||||||
|
onLoad={() => advance(1)}
|
||||||
|
/>
|
||||||
|
{/* Stage 2: full resolution */}
|
||||||
|
<img
|
||||||
|
src={fullSrc}
|
||||||
|
alt={alt}
|
||||||
|
loading="lazy"
|
||||||
|
className={`relative w-full h-full object-cover transition-opacity duration-300 ${
|
||||||
|
stage === 2 ? 'opacity-100' : 'opacity-0'
|
||||||
|
}`}
|
||||||
|
onLoad={() => advance(2)}
|
||||||
|
onClick={onClick}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
22
frontend/src/context/GenerationContext.tsx
Normal file
22
frontend/src/context/GenerationContext.tsx
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import React, { createContext, useContext, useState, useCallback } from 'react'
|
||||||
|
|
||||||
|
interface Value {
|
||||||
|
pendingCount: number
|
||||||
|
addPending: (n?: number) => void
|
||||||
|
decrementPending: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const Ctx = createContext<Value>({
|
||||||
|
pendingCount: 0,
|
||||||
|
addPending: () => {},
|
||||||
|
decrementPending: () => {},
|
||||||
|
})
|
||||||
|
|
||||||
|
export function GenerationProvider({ children }: { children: React.ReactNode }) {
|
||||||
|
const [pendingCount, setPendingCount] = useState(0)
|
||||||
|
const addPending = useCallback((n = 1) => setPendingCount(c => c + n), [])
|
||||||
|
const decrementPending = useCallback(() => setPendingCount(c => Math.max(0, c - 1)), [])
|
||||||
|
return <Ctx.Provider value={{ pendingCount, addPending, decrementPending }}>{children}</Ctx.Provider>
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useGeneration = () => useContext(Ctx)
|
||||||
26
frontend/src/hooks/useAuth.ts
Normal file
26
frontend/src/hooks/useAuth.ts
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import { useQuery, useQueryClient } from '@tanstack/react-query'
|
||||||
|
import { authMe, authLogout } from '../api/client'
|
||||||
|
|
||||||
|
export function useAuth() {
|
||||||
|
const qc = useQueryClient()
|
||||||
|
const { data, isLoading, error } = useQuery({
|
||||||
|
queryKey: ['auth', 'me'],
|
||||||
|
queryFn: authMe,
|
||||||
|
retry: false,
|
||||||
|
staleTime: 60_000,
|
||||||
|
})
|
||||||
|
|
||||||
|
const logout = async () => {
|
||||||
|
await authLogout()
|
||||||
|
qc.clear()
|
||||||
|
window.location.href = '/login'
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
user: data ?? null,
|
||||||
|
isLoading,
|
||||||
|
isAuthenticated: !!data,
|
||||||
|
isAdmin: data?.admin ?? false,
|
||||||
|
logout,
|
||||||
|
}
|
||||||
|
}
|
||||||
71
frontend/src/hooks/useStatus.ts
Normal file
71
frontend/src/hooks/useStatus.ts
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
import { useState, useCallback } from 'react'
|
||||||
|
import { useWebSocket } from './useWebSocket'
|
||||||
|
|
||||||
|
export interface StatusSnapshot {
|
||||||
|
bot?: { latency_ms: number; uptime: string }
|
||||||
|
comfy?: {
|
||||||
|
server: string
|
||||||
|
reachable?: boolean
|
||||||
|
queue_pending: number
|
||||||
|
queue_running: number
|
||||||
|
workflow_loaded: boolean
|
||||||
|
last_seed: number | null
|
||||||
|
total_generated: number
|
||||||
|
}
|
||||||
|
overrides?: Record<string, unknown>
|
||||||
|
service?: { state: string }
|
||||||
|
upload?: { configured: boolean; running: boolean; total_ok: number; total_fail: number }
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GenerationResult {
|
||||||
|
prompt_id: string
|
||||||
|
seed: number | null
|
||||||
|
image_count: number
|
||||||
|
video_count: number
|
||||||
|
}
|
||||||
|
|
||||||
|
interface UseStatusOptions {
|
||||||
|
enabled?: boolean
|
||||||
|
onGenerationComplete?: (result: GenerationResult) => void
|
||||||
|
onGenerationError?: (error: { prompt_id: string | null; error: string }) => void
|
||||||
|
onNodeExecuting?: (node: string, promptId: string) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useStatus({
|
||||||
|
enabled = true,
|
||||||
|
onGenerationComplete,
|
||||||
|
onGenerationError,
|
||||||
|
onNodeExecuting,
|
||||||
|
}: UseStatusOptions) {
|
||||||
|
const [status, setStatus] = useState<StatusSnapshot>({})
|
||||||
|
const [executingNode, setExecutingNode] = useState<string | null>(null)
|
||||||
|
|
||||||
|
const handleMessage = useCallback(
|
||||||
|
(msg: { type: string; data: unknown; ts: number }) => {
|
||||||
|
if (msg.type === 'status_snapshot') {
|
||||||
|
setStatus(msg.data as StatusSnapshot)
|
||||||
|
} else if (msg.type === 'node_executing') {
|
||||||
|
const d = msg.data as { node: string; prompt_id: string }
|
||||||
|
setExecutingNode(d.node)
|
||||||
|
onNodeExecuting?.(d.node, d.prompt_id)
|
||||||
|
} else if (msg.type === 'generation_complete') {
|
||||||
|
setExecutingNode(null)
|
||||||
|
onGenerationComplete?.(msg.data as GenerationResult)
|
||||||
|
} else if (msg.type === 'generation_error') {
|
||||||
|
setExecutingNode(null)
|
||||||
|
onGenerationError?.(msg.data as { prompt_id: string | null; error: string })
|
||||||
|
} else if (msg.type === 'server_state') {
|
||||||
|
const d = msg.data as { state: string; http_reachable: boolean }
|
||||||
|
setStatus(prev => ({
|
||||||
|
...prev,
|
||||||
|
service: { state: d.state },
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[onGenerationComplete, onGenerationError, onNodeExecuting],
|
||||||
|
)
|
||||||
|
|
||||||
|
useWebSocket({ onMessage: handleMessage, enabled })
|
||||||
|
|
||||||
|
return { status, executingNode }
|
||||||
|
}
|
||||||
55
frontend/src/hooks/useWebSocket.ts
Normal file
55
frontend/src/hooks/useWebSocket.ts
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import { useEffect, useRef, useCallback } from 'react'
|
||||||
|
|
||||||
|
interface WSOptions {
|
||||||
|
onMessage: (data: { type: string; data: unknown; ts: number }) => void
|
||||||
|
enabled?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Reconnecting WebSocket hook with exponential backoff. Auth via ttb_session cookie. */
|
||||||
|
export function useWebSocket({ onMessage, enabled = true }: WSOptions) {
|
||||||
|
const wsRef = useRef<WebSocket | null>(null)
|
||||||
|
const backoffRef = useRef(1000)
|
||||||
|
const timerRef = useRef<ReturnType<typeof setTimeout> | null>(null)
|
||||||
|
const onMessageRef = useRef(onMessage)
|
||||||
|
onMessageRef.current = onMessage
|
||||||
|
|
||||||
|
const connect = useCallback(() => {
|
||||||
|
if (!enabled) return
|
||||||
|
const proto = window.location.protocol === 'https:' ? 'wss' : 'ws'
|
||||||
|
const url = `${proto}://${window.location.host}/ws`
|
||||||
|
const ws = new WebSocket(url)
|
||||||
|
wsRef.current = ws
|
||||||
|
|
||||||
|
ws.onmessage = (ev) => {
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(ev.data)
|
||||||
|
if (parsed.type !== 'ping') {
|
||||||
|
onMessageRef.current(parsed)
|
||||||
|
}
|
||||||
|
} catch {}
|
||||||
|
}
|
||||||
|
|
||||||
|
ws.onopen = () => {
|
||||||
|
backoffRef.current = 1000
|
||||||
|
}
|
||||||
|
|
||||||
|
ws.onclose = () => {
|
||||||
|
if (enabled) {
|
||||||
|
timerRef.current = setTimeout(() => {
|
||||||
|
backoffRef.current = Math.min(backoffRef.current * 2, 30_000)
|
||||||
|
connect()
|
||||||
|
}, backoffRef.current)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ws.onerror = () => ws.close()
|
||||||
|
}, [enabled])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
connect()
|
||||||
|
return () => {
|
||||||
|
enabled && (wsRef.current?.close())
|
||||||
|
if (timerRef.current) clearTimeout(timerRef.current)
|
||||||
|
}
|
||||||
|
}, [connect, enabled])
|
||||||
|
}
|
||||||
12
frontend/src/index.css
Normal file
12
frontend/src/index.css
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
@tailwind base;
|
||||||
|
@tailwind components;
|
||||||
|
@tailwind utilities;
|
||||||
|
|
||||||
|
:root {
|
||||||
|
color-scheme: light dark;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
@apply bg-gray-50 text-gray-900 dark:bg-gray-900 dark:text-gray-100;
|
||||||
|
font-family: system-ui, -apple-system, sans-serif;
|
||||||
|
}
|
||||||
19
frontend/src/main.tsx
Normal file
19
frontend/src/main.tsx
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import ReactDOM from 'react-dom/client'
|
||||||
|
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||||
|
import App from './App'
|
||||||
|
import './index.css'
|
||||||
|
|
||||||
|
const queryClient = new QueryClient({
|
||||||
|
defaultOptions: {
|
||||||
|
queries: { retry: 1, staleTime: 5000 },
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ReactDOM.createRoot(document.getElementById('root')!).render(
|
||||||
|
<React.StrictMode>
|
||||||
|
<QueryClientProvider client={queryClient}>
|
||||||
|
<App />
|
||||||
|
</QueryClientProvider>
|
||||||
|
</React.StrictMode>,
|
||||||
|
)
|
||||||
113
frontend/src/pages/AdminPage.tsx
Normal file
113
frontend/src/pages/AdminPage.tsx
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import React, { useState } from 'react'
|
||||||
|
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||||
|
import { adminListTokens, adminCreateToken, adminRevokeToken } from '../api/client'
|
||||||
|
|
||||||
|
export default function AdminPage() {
|
||||||
|
const qc = useQueryClient()
|
||||||
|
const [label, setLabel] = useState('')
|
||||||
|
const [isAdmin, setIsAdmin] = useState(false)
|
||||||
|
const [newToken, setNewToken] = useState<string | null>(null)
|
||||||
|
const [createError, setCreateError] = useState('')
|
||||||
|
|
||||||
|
const { data: tokens = [], isLoading } = useQuery({
|
||||||
|
queryKey: ['admin', 'tokens'],
|
||||||
|
queryFn: adminListTokens,
|
||||||
|
})
|
||||||
|
|
||||||
|
const createMut = useMutation({
|
||||||
|
mutationFn: ({ label, admin }: { label: string; admin: boolean }) => adminCreateToken(label, admin),
|
||||||
|
onSuccess: (res) => {
|
||||||
|
qc.invalidateQueries({ queryKey: ['admin', 'tokens'] })
|
||||||
|
setNewToken(res.token)
|
||||||
|
setLabel('')
|
||||||
|
setIsAdmin(false)
|
||||||
|
setCreateError('')
|
||||||
|
},
|
||||||
|
onError: (err) => setCreateError(err instanceof Error ? err.message : String(err)),
|
||||||
|
})
|
||||||
|
|
||||||
|
const revokeMut = useMutation({
|
||||||
|
mutationFn: (id: string) => adminRevokeToken(id),
|
||||||
|
onSuccess: () => qc.invalidateQueries({ queryKey: ['admin', 'tokens'] }),
|
||||||
|
})
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="max-w-xl mx-auto space-y-6">
|
||||||
|
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Admin — Token Management</h1>
|
||||||
|
|
||||||
|
{/* Create token */}
|
||||||
|
<div className="bg-white dark:bg-gray-800 rounded border border-gray-200 dark:border-gray-700 p-4 space-y-3">
|
||||||
|
<p className="text-sm font-medium text-gray-700 dark:text-gray-300">Create invite token</p>
|
||||||
|
<div className="flex flex-col sm:flex-row gap-2">
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={label}
|
||||||
|
onChange={e => setLabel(e.target.value)}
|
||||||
|
placeholder="Label (e.g. alice)"
|
||||||
|
className="flex-1 border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||||
|
/>
|
||||||
|
<label className="flex items-center gap-1 text-sm text-gray-600 dark:text-gray-400 cursor-pointer">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={isAdmin}
|
||||||
|
onChange={e => setIsAdmin(e.target.checked)}
|
||||||
|
className="rounded"
|
||||||
|
/>
|
||||||
|
Admin
|
||||||
|
</label>
|
||||||
|
<button
|
||||||
|
onClick={() => createMut.mutate({ label, admin: isAdmin })}
|
||||||
|
disabled={!label.trim() || createMut.isPending}
|
||||||
|
className="bg-blue-600 hover:bg-blue-700 disabled:opacity-50 text-white rounded px-3 py-2 text-sm font-medium"
|
||||||
|
>
|
||||||
|
Create
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
{createError && <p className="text-red-500 text-sm">{createError}</p>}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* New token display — one-time */}
|
||||||
|
{newToken && (
|
||||||
|
<div className="bg-yellow-50 dark:bg-yellow-900/30 border border-yellow-300 dark:border-yellow-700 rounded p-4 space-y-2">
|
||||||
|
<p className="text-sm font-medium text-yellow-800 dark:text-yellow-200">
|
||||||
|
New token (copy now — shown only once):
|
||||||
|
</p>
|
||||||
|
<code className="block text-xs break-all bg-yellow-100 dark:bg-yellow-900/50 rounded p-2 text-yellow-900 dark:text-yellow-100 select-all">
|
||||||
|
{newToken}
|
||||||
|
</code>
|
||||||
|
<button
|
||||||
|
onClick={() => setNewToken(null)}
|
||||||
|
className="text-xs text-yellow-600 dark:text-yellow-400 hover:underline"
|
||||||
|
>
|
||||||
|
Dismiss
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Token list */}
|
||||||
|
{isLoading ? (
|
||||||
|
<div className="text-sm text-gray-400">Loading tokens…</div>
|
||||||
|
) : tokens.length === 0 ? (
|
||||||
|
<p className="text-sm text-gray-400">No tokens yet.</p>
|
||||||
|
) : (
|
||||||
|
<div className="divide-y divide-gray-200 dark:divide-gray-700 border border-gray-200 dark:border-gray-700 rounded">
|
||||||
|
{tokens.map(t => (
|
||||||
|
<div key={t.id} className="flex items-center justify-between px-3 py-2 text-sm">
|
||||||
|
<div>
|
||||||
|
<span className="font-medium text-gray-700 dark:text-gray-300">{t.label}</span>
|
||||||
|
{t.admin && <span className="ml-1 text-xs text-purple-600 dark:text-purple-400">(admin)</span>}
|
||||||
|
<span className="ml-2 text-xs text-gray-400">{new Date(t.created_at).toLocaleDateString()}</span>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
onClick={() => revokeMut.mutate(t.id)}
|
||||||
|
className="text-xs text-red-500 hover:text-red-700 dark:hover:text-red-400"
|
||||||
|
>
|
||||||
|
Revoke
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
250
frontend/src/pages/GeneratePage.tsx
Normal file
250
frontend/src/pages/GeneratePage.tsx
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
import React, { useState, useCallback } from 'react'
|
||||||
|
import { useQuery, useQueryClient } from '@tanstack/react-query'
|
||||||
|
import { generate, workflowGen, listPresets, loadPreset } from '../api/client'
|
||||||
|
import { useStatus, GenerationResult } from '../hooks/useStatus'
|
||||||
|
import { useAuth } from '../hooks/useAuth'
|
||||||
|
import { useGeneration } from '../context/GenerationContext'
|
||||||
|
import DynamicWorkflowForm from '../components/DynamicWorkflowForm'
|
||||||
|
|
||||||
|
interface Notification {
|
||||||
|
id: number
|
||||||
|
type: 'success' | 'error' | 'info'
|
||||||
|
msg: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function GeneratePage() {
|
||||||
|
const { user } = useAuth()
|
||||||
|
const { pendingCount, addPending } = useGeneration()
|
||||||
|
const qc = useQueryClient()
|
||||||
|
const [generating, setGenerating] = useState(false)
|
||||||
|
const [notifications, setNotifications] = useState<Notification[]>([])
|
||||||
|
const [executingNodeDisplay, setExecutingNodeDisplay] = useState<string | null>(null)
|
||||||
|
const [lastSeed, setLastSeed] = useState<number | null>(null)
|
||||||
|
const [mode, setMode] = useState<'workflow' | 'prompt'>('workflow')
|
||||||
|
const [prompt, setPrompt] = useState('')
|
||||||
|
const [negPrompt, setNegPrompt] = useState('')
|
||||||
|
const [promptCount, setPromptCount] = useState(1)
|
||||||
|
const [loadingPreset, setLoadingPreset] = useState(false)
|
||||||
|
|
||||||
|
const { data: presetsData } = useQuery({ queryKey: ['presets'], queryFn: listPresets })
|
||||||
|
|
||||||
|
const addNotif = useCallback((type: Notification['type'], msg: string, ttl = 8000) => {
|
||||||
|
const id = Date.now()
|
||||||
|
setNotifications(n => [...n, { id, type, msg }])
|
||||||
|
setTimeout(() => setNotifications(n => n.filter(x => x.id !== id)), ttl)
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const handlePresetLoad = useCallback(async (name: string) => {
|
||||||
|
if (!name) return
|
||||||
|
setLoadingPreset(true)
|
||||||
|
try {
|
||||||
|
await loadPreset(name)
|
||||||
|
qc.invalidateQueries({ queryKey: ['state'] })
|
||||||
|
qc.invalidateQueries({ queryKey: ['workflowInputs'] })
|
||||||
|
addNotif('info', `Loaded preset: ${name}`, 3000)
|
||||||
|
} catch (err: unknown) {
|
||||||
|
addNotif('error', `Failed to load preset: ${err instanceof Error ? err.message : String(err)}`)
|
||||||
|
} finally {
|
||||||
|
setLoadingPreset(false)
|
||||||
|
}
|
||||||
|
}, [qc, addNotif])
|
||||||
|
|
||||||
|
const dismissNotif = useCallback((id: number) => {
|
||||||
|
setNotifications(n => n.filter(x => x.id !== id))
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const onGenerationComplete = useCallback((r: GenerationResult) => {
|
||||||
|
setExecutingNodeDisplay(null)
|
||||||
|
setLastSeed(r.seed)
|
||||||
|
addNotif('success', `Done — seed: ${r.seed ?? 'unknown'} · ${r.image_count} image(s) · ${r.video_count} video(s)`)
|
||||||
|
}, [addNotif])
|
||||||
|
|
||||||
|
const onGenerationError = useCallback((e: { prompt_id: string | null; error: string }) => {
|
||||||
|
setExecutingNodeDisplay(null)
|
||||||
|
addNotif('error', e.error)
|
||||||
|
}, [addNotif])
|
||||||
|
|
||||||
|
const onNodeExecuting = useCallback((node: string) => {
|
||||||
|
setExecutingNodeDisplay(node)
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const { status } = useStatus({
|
||||||
|
enabled: !!user,
|
||||||
|
onGenerationComplete,
|
||||||
|
onGenerationError,
|
||||||
|
onNodeExecuting,
|
||||||
|
})
|
||||||
|
|
||||||
|
const handleWorkflowGenerate = async (overrides: Record<string, unknown>, count: number = 1) => {
|
||||||
|
setGenerating(true)
|
||||||
|
document.title = '⏳ Queued… | ComfyUI Bot'
|
||||||
|
try {
|
||||||
|
const res = await workflowGen({ count, overrides })
|
||||||
|
setGenerating(false)
|
||||||
|
addPending(res.count ?? count)
|
||||||
|
addNotif('info', `Queued ${res.count ?? count} generation(s) at position ${res.queue_position}`, 3000)
|
||||||
|
} catch (err: unknown) {
|
||||||
|
setGenerating(false)
|
||||||
|
addNotif('error', err instanceof Error ? err.message : String(err))
|
||||||
|
document.title = 'ComfyUI Bot'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handlePromptGenerate = async () => {
|
||||||
|
const n = Math.max(1, Math.min(20, promptCount))
|
||||||
|
setGenerating(true)
|
||||||
|
document.title = '⏳ Queued… | ComfyUI Bot'
|
||||||
|
try {
|
||||||
|
let lastPos = 0
|
||||||
|
for (let i = 0; i < n; i++) {
|
||||||
|
const res = await generate({ prompt, negative_prompt: negPrompt })
|
||||||
|
lastPos = res.queue_position
|
||||||
|
}
|
||||||
|
setGenerating(false)
|
||||||
|
addPending(n)
|
||||||
|
addNotif('info', `Queued ${n} generation(s) at position ${lastPos}`, 3000)
|
||||||
|
} catch (err: unknown) {
|
||||||
|
setGenerating(false)
|
||||||
|
addNotif('error', err instanceof Error ? err.message : String(err))
|
||||||
|
document.title = 'ComfyUI Bot'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const queuePending = (status.comfy?.queue_pending ?? 0)
|
||||||
|
const queueRunning = (status.comfy?.queue_running ?? 0)
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="max-w-2xl mx-auto space-y-6">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Generate</h1>
|
||||||
|
<div className="text-xs text-gray-400">
|
||||||
|
ComfyUI: {queueRunning} running, {queuePending} pending
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Mode toggle */}
|
||||||
|
<div className="flex flex-wrap gap-2 text-sm">
|
||||||
|
<button
|
||||||
|
onClick={() => setMode('workflow')}
|
||||||
|
className={`px-3 py-1.5 rounded ${mode === 'workflow' ? 'bg-blue-600 text-white' : 'bg-gray-200 dark:bg-gray-700 text-gray-700 dark:text-gray-300'}`}
|
||||||
|
>
|
||||||
|
Workflow mode
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={() => setMode('prompt')}
|
||||||
|
className={`px-3 py-1.5 rounded ${mode === 'prompt' ? 'bg-blue-600 text-white' : 'bg-gray-200 dark:bg-gray-700 text-gray-700 dark:text-gray-300'}`}
|
||||||
|
>
|
||||||
|
Prompt mode
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Quick-load preset */}
|
||||||
|
{(presetsData?.presets ?? []).length > 0 && (
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<select
|
||||||
|
defaultValue=""
|
||||||
|
disabled={loadingPreset}
|
||||||
|
onChange={e => {
|
||||||
|
const v = e.target.value
|
||||||
|
e.target.value = ''
|
||||||
|
if (v) handlePresetLoad(v)
|
||||||
|
}}
|
||||||
|
className="flex-1 border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-700 dark:text-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-500 disabled:opacity-50"
|
||||||
|
>
|
||||||
|
<option value="" disabled>Load a preset…</option>
|
||||||
|
{(presetsData?.presets ?? []).map(p => (
|
||||||
|
<option key={p.name} value={p.name}>
|
||||||
|
{p.name}{p.description ? ` — ${p.description}` : ''}
|
||||||
|
</option>
|
||||||
|
))}
|
||||||
|
</select>
|
||||||
|
{loadingPreset && <span className="text-xs text-gray-400">Loading…</span>}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Progress banner */}
|
||||||
|
{(pendingCount > 0 || executingNodeDisplay) && (
|
||||||
|
<div className="bg-blue-50 dark:bg-blue-900/30 border border-blue-200 dark:border-blue-700 rounded p-3 text-sm text-blue-700 dark:text-blue-300">
|
||||||
|
{pendingCount > 0 && `${pendingCount} generation(s) in progress`}
|
||||||
|
{executingNodeDisplay && ` · running: ${executingNodeDisplay}`}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Form */}
|
||||||
|
{mode === 'workflow' ? (
|
||||||
|
<DynamicWorkflowForm
|
||||||
|
onGenerate={handleWorkflowGenerate}
|
||||||
|
lastSeed={lastSeed}
|
||||||
|
generating={generating}
|
||||||
|
userLabel={user?.label}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<div className="space-y-4">
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300 mb-1">Prompt</label>
|
||||||
|
<textarea
|
||||||
|
rows={3}
|
||||||
|
className="w-full border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 resize-y focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||||
|
value={prompt}
|
||||||
|
onChange={e => setPrompt(e.target.value)}
|
||||||
|
placeholder="Describe what you want to generate"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300 mb-1">Negative prompt</label>
|
||||||
|
<textarea
|
||||||
|
rows={2}
|
||||||
|
className="w-full border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 resize-y focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||||
|
value={negPrompt}
|
||||||
|
onChange={e => setNegPrompt(e.target.value)}
|
||||||
|
placeholder="What to avoid"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="flex gap-2 items-center">
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
min={1}
|
||||||
|
max={20}
|
||||||
|
value={promptCount}
|
||||||
|
onChange={e => setPromptCount(Math.max(1, Math.min(20, Number(e.target.value))))}
|
||||||
|
className="w-16 border border-gray-300 dark:border-gray-600 rounded px-2 py-2 text-sm text-center bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
|
||||||
|
title="Number of generations to queue"
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
onClick={handlePromptGenerate}
|
||||||
|
disabled={generating || !prompt.trim()}
|
||||||
|
className="flex-1 bg-blue-600 hover:bg-blue-700 disabled:opacity-50 text-white rounded px-4 py-2 text-sm font-semibold transition-colors"
|
||||||
|
>
|
||||||
|
{generating ? '⏳ Queuing…' : promptCount > 1 ? `Generate ×${promptCount}` : 'Generate'}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Toast notification stack */}
|
||||||
|
<div className="fixed bottom-4 right-4 z-50 space-y-2 max-w-sm">
|
||||||
|
{notifications.map(n => (
|
||||||
|
<div
|
||||||
|
key={n.id}
|
||||||
|
className={`flex items-start gap-2 rounded border p-3 text-sm shadow-md ${
|
||||||
|
n.type === 'success'
|
||||||
|
? 'bg-green-50 dark:bg-green-900/40 border-green-200 dark:border-green-700 text-green-700 dark:text-green-300'
|
||||||
|
: n.type === 'error'
|
||||||
|
? 'bg-red-50 dark:bg-red-900/40 border-red-200 dark:border-red-700 text-red-700 dark:text-red-300'
|
||||||
|
: 'bg-blue-50 dark:bg-blue-900/40 border-blue-200 dark:border-blue-700 text-blue-700 dark:text-blue-300'
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
<span className="flex-1">{n.msg}</span>
|
||||||
|
<button
|
||||||
|
onClick={() => dismissNotif(n.id)}
|
||||||
|
className="flex-none opacity-60 hover:opacity-100 leading-none"
|
||||||
|
aria-label="Dismiss"
|
||||||
|
>
|
||||||
|
✕
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
382
frontend/src/pages/HistoryPage.tsx
Normal file
382
frontend/src/pages/HistoryPage.tsx
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
import React, { useState, useEffect, useRef } from 'react'
|
||||||
|
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||||
|
import { getHistory, createHistoryShare, revokeHistoryShare, savePresetFromHistory } from '../api/client'
|
||||||
|
import { useAuth } from '../hooks/useAuth'
|
||||||
|
|
||||||
|
interface HistoryRow {
|
||||||
|
id: number
|
||||||
|
prompt_id: string
|
||||||
|
source: string
|
||||||
|
user_label?: string
|
||||||
|
overrides: Record<string, unknown>
|
||||||
|
seed?: number
|
||||||
|
file_paths?: string[]
|
||||||
|
created_at: string
|
||||||
|
share_token?: string | null
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Debounce a value by `delay` ms. */
|
||||||
|
function useDebounce<T>(value: T, delay: number): T {
|
||||||
|
const [debounced, setDebounced] = useState(value)
|
||||||
|
useEffect(() => {
|
||||||
|
const t = setTimeout(() => setDebounced(value), delay)
|
||||||
|
return () => clearTimeout(t)
|
||||||
|
}, [value, delay])
|
||||||
|
return debounced
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function HistoryPage() {
|
||||||
|
const [lightbox, setLightbox] = useState<string | null>(null)
|
||||||
|
const [expandedId, setExpandedId] = useState<string | null>(null)
|
||||||
|
const [searchInput, setSearchInput] = useState('')
|
||||||
|
const { user } = useAuth()
|
||||||
|
|
||||||
|
const debouncedQ = useDebounce(searchInput.trim(), 300)
|
||||||
|
|
||||||
|
const { data, isLoading } = useQuery({
|
||||||
|
queryKey: ['history', debouncedQ],
|
||||||
|
queryFn: () => getHistory(debouncedQ || undefined),
|
||||||
|
refetchInterval: 10_000,
|
||||||
|
})
|
||||||
|
|
||||||
|
const rows: HistoryRow[] = ((data?.history ?? []) as unknown as HistoryRow[])
|
||||||
|
|
||||||
|
const isSearching = searchInput.trim() !== debouncedQ
|
||||||
|
|
||||||
|
if (isLoading && !searchInput) return <div className="text-sm text-gray-400">Loading history…</div>
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-4">
|
||||||
|
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">History</h1>
|
||||||
|
|
||||||
|
{/* Search */}
|
||||||
|
<div className="relative">
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={searchInput}
|
||||||
|
onChange={e => setSearchInput(e.target.value)}
|
||||||
|
placeholder="Search by prompt, checkpoint, seed…"
|
||||||
|
className="w-full border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||||
|
/>
|
||||||
|
{searchInput && (
|
||||||
|
<button
|
||||||
|
onClick={() => setSearchInput('')}
|
||||||
|
className="absolute right-2 top-1/2 -translate-y-1/2 text-gray-400 hover:text-gray-600 dark:hover:text-gray-300 text-xs px-1"
|
||||||
|
>
|
||||||
|
✕
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
{isSearching && <p className="text-xs text-gray-400">Searching…</p>}
|
||||||
|
|
||||||
|
{isLoading ? (
|
||||||
|
<div className="text-sm text-gray-400">Loading…</div>
|
||||||
|
) : rows.length === 0 ? (
|
||||||
|
<p className="text-sm text-gray-400">
|
||||||
|
{debouncedQ ? `No results for "${debouncedQ}".` : 'No generation history yet.'}
|
||||||
|
</p>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
{/* Desktop table — hidden on mobile */}
|
||||||
|
<div className="hidden sm:block overflow-x-auto">
|
||||||
|
<table className="w-full text-sm border-collapse">
|
||||||
|
<thead>
|
||||||
|
<tr className="text-left text-xs text-gray-400 border-b border-gray-200 dark:border-gray-700">
|
||||||
|
<th className="pb-2 pr-4">Time</th>
|
||||||
|
<th className="pb-2 pr-4">Source</th>
|
||||||
|
<th className="pb-2 pr-4">User</th>
|
||||||
|
<th className="pb-2 pr-4">Seed</th>
|
||||||
|
<th className="pb-2">Files</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{rows.map(row => (
|
||||||
|
<React.Fragment key={row.prompt_id}>
|
||||||
|
<tr
|
||||||
|
className="border-b border-gray-100 dark:border-gray-800 hover:bg-gray-50 dark:hover:bg-gray-800/50 cursor-pointer"
|
||||||
|
onClick={() => setExpandedId(expandedId === row.prompt_id ? null : row.prompt_id)}
|
||||||
|
>
|
||||||
|
<td className="py-2 pr-4 text-gray-500 dark:text-gray-400 whitespace-nowrap">
|
||||||
|
{new Date(row.created_at).toLocaleString()}
|
||||||
|
</td>
|
||||||
|
<td className="py-2 pr-4">
|
||||||
|
<span className={`text-xs px-1.5 py-0.5 rounded ${row.source === 'web' ? 'bg-blue-100 dark:bg-blue-900 text-blue-700 dark:text-blue-300' : 'bg-purple-100 dark:bg-purple-900 text-purple-700 dark:text-purple-300'}`}>
|
||||||
|
{row.source}
|
||||||
|
</span>
|
||||||
|
</td>
|
||||||
|
<td className="py-2 pr-4 text-gray-600 dark:text-gray-400">{row.user_label ?? '—'}</td>
|
||||||
|
<td className="py-2 pr-4 font-mono text-gray-700 dark:text-gray-300">{row.seed ?? '—'}</td>
|
||||||
|
<td className="py-2 text-gray-500 dark:text-gray-400">{(row.file_paths ?? []).length} file(s)</td>
|
||||||
|
</tr>
|
||||||
|
{expandedId === row.prompt_id && (
|
||||||
|
<tr className="bg-gray-50 dark:bg-gray-800/50">
|
||||||
|
<td colSpan={5} className="px-3 py-3">
|
||||||
|
<ExpandedRow row={row} onLightbox={setLightbox} currentUserLabel={user?.label ?? null} />
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
)}
|
||||||
|
</React.Fragment>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Mobile card list — hidden on sm+ */}
|
||||||
|
<div className="sm:hidden space-y-2">
|
||||||
|
{rows.map(row => {
|
||||||
|
const isExpanded = expandedId === row.prompt_id
|
||||||
|
const dt = new Date(row.created_at)
|
||||||
|
const timeStr = dt.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })
|
||||||
|
const dateStr = dt.toLocaleDateString([], { month: 'short', day: 'numeric' })
|
||||||
|
return (
|
||||||
|
<div key={row.prompt_id} className="border border-gray-200 dark:border-gray-700 rounded bg-white dark:bg-gray-800">
|
||||||
|
<button
|
||||||
|
className="w-full text-left px-3 py-2 space-y-1"
|
||||||
|
onClick={() => setExpandedId(isExpanded ? null : row.prompt_id)}
|
||||||
|
>
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<span className="text-xs text-gray-500 dark:text-gray-400">{timeStr} · {dateStr}</span>
|
||||||
|
<span className={`text-xs px-1.5 py-0.5 rounded ${row.source === 'web' ? 'bg-blue-100 dark:bg-blue-900 text-blue-700 dark:text-blue-300' : 'bg-purple-100 dark:bg-purple-900 text-purple-700 dark:text-purple-300'}`}>
|
||||||
|
{row.source}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-3 text-xs text-gray-600 dark:text-gray-400">
|
||||||
|
<span>User: {row.user_label ?? '—'}</span>
|
||||||
|
<span>Seed: {row.seed ?? '—'}</span>
|
||||||
|
<span>{(row.file_paths ?? []).length} file(s)</span>
|
||||||
|
<span className="ml-auto text-gray-400">{isExpanded ? '▲' : '▼'}</span>
|
||||||
|
</div>
|
||||||
|
</button>
|
||||||
|
{isExpanded && (
|
||||||
|
<div className="px-3 pb-3 border-t border-gray-100 dark:border-gray-700 pt-2">
|
||||||
|
<ExpandedRow row={row} onLightbox={setLightbox} currentUserLabel={user?.label ?? null} />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Lightbox */}
|
||||||
|
{lightbox && (
|
||||||
|
<div
|
||||||
|
className="fixed inset-0 bg-black/80 flex items-center justify-center z-50"
|
||||||
|
onClick={() => setLightbox(null)}
|
||||||
|
>
|
||||||
|
<button
|
||||||
|
className="absolute top-3 right-3 text-white text-2xl leading-none hover:text-gray-300"
|
||||||
|
onClick={() => setLightbox(null)}
|
||||||
|
aria-label="Close"
|
||||||
|
>
|
||||||
|
✕
|
||||||
|
</button>
|
||||||
|
<img src={lightbox} alt="preview" className="max-w-[90vw] max-h-[90vh] object-contain rounded" />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function ExpandedRow({
|
||||||
|
row,
|
||||||
|
onLightbox,
|
||||||
|
currentUserLabel,
|
||||||
|
}: {
|
||||||
|
row: HistoryRow
|
||||||
|
onLightbox: (url: string) => void
|
||||||
|
currentUserLabel: string | null
|
||||||
|
}) {
|
||||||
|
const qc = useQueryClient()
|
||||||
|
const [copied, setCopied] = useState(false)
|
||||||
|
const [showSavePreset, setShowSavePreset] = useState(false)
|
||||||
|
const [presetName, setPresetName] = useState('')
|
||||||
|
const [presetDesc, setPresetDesc] = useState('')
|
||||||
|
const [presetMsg, setPresetMsg] = useState<{ ok: boolean; text: string } | null>(null)
|
||||||
|
const presetNameRef = useRef<HTMLInputElement>(null)
|
||||||
|
|
||||||
|
const { data: imagesData, isLoading } = useQuery({
|
||||||
|
queryKey: ['history', row.prompt_id, 'images'],
|
||||||
|
queryFn: async () => {
|
||||||
|
const res = await fetch(`/api/history/${row.prompt_id}/images`, { credentials: 'include' })
|
||||||
|
if (!res.ok) throw new Error(`${res.status}`)
|
||||||
|
return res.json() as Promise<{ images: Array<{ filename: string; data: string | null; mime_type: string }> }>
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const shareMut = useMutation({
|
||||||
|
mutationFn: () => createHistoryShare(row.prompt_id),
|
||||||
|
onSuccess: () => qc.invalidateQueries({ queryKey: ['history'] }),
|
||||||
|
})
|
||||||
|
|
||||||
|
const revokeMut = useMutation({
|
||||||
|
mutationFn: () => revokeHistoryShare(row.prompt_id),
|
||||||
|
onSuccess: () => qc.invalidateQueries({ queryKey: ['history'] }),
|
||||||
|
})
|
||||||
|
|
||||||
|
const savePresetMut = useMutation({
|
||||||
|
mutationFn: () => savePresetFromHistory(row.prompt_id, presetName.trim(), presetDesc.trim() || undefined),
|
||||||
|
onSuccess: (res) => {
|
||||||
|
qc.invalidateQueries({ queryKey: ['presets'] })
|
||||||
|
setPresetMsg({ ok: true, text: `Saved as preset "${res.name}"` })
|
||||||
|
setPresetName('')
|
||||||
|
setPresetDesc('')
|
||||||
|
},
|
||||||
|
onError: (err) => {
|
||||||
|
setPresetMsg({ ok: false, text: err instanceof Error ? err.message : String(err) })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const shareUrl = row.share_token
|
||||||
|
? `${window.location.origin}/share/${row.share_token}`
|
||||||
|
: null
|
||||||
|
|
||||||
|
const isOwner = currentUserLabel !== null && row.user_label === currentUserLabel
|
||||||
|
|
||||||
|
const handleCopy = () => {
|
||||||
|
if (!shareUrl) return
|
||||||
|
navigator.clipboard.writeText(shareUrl).then(() => {
|
||||||
|
setCopied(true)
|
||||||
|
setTimeout(() => setCopied(false), 2000)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-3">
|
||||||
|
{/* Overrides */}
|
||||||
|
<details>
|
||||||
|
<summary className="text-xs text-gray-400 cursor-pointer">Overrides</summary>
|
||||||
|
<pre className="text-xs bg-gray-100 dark:bg-gray-700 rounded p-2 mt-1 overflow-auto max-h-32 text-gray-600 dark:text-gray-400">
|
||||||
|
{JSON.stringify(row.overrides, null, 2)}
|
||||||
|
</pre>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
{/* Images */}
|
||||||
|
{isLoading ? (
|
||||||
|
<p className="text-xs text-gray-400">Loading images…</p>
|
||||||
|
) : (imagesData?.images ?? []).length > 0 ? (
|
||||||
|
<div className="flex gap-2 flex-wrap">
|
||||||
|
{(imagesData?.images ?? []).map((img, i) => {
|
||||||
|
const isVideo = img.mime_type.startsWith('video/')
|
||||||
|
if (isVideo) {
|
||||||
|
const videoSrc = `/api/history/${row.prompt_id}/file/${encodeURIComponent(img.filename)}`
|
||||||
|
return (
|
||||||
|
<video
|
||||||
|
key={i}
|
||||||
|
src={videoSrc}
|
||||||
|
controls
|
||||||
|
className="rounded max-h-40"
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
const src = `data:${img.mime_type};base64,${img.data}`
|
||||||
|
return (
|
||||||
|
<img
|
||||||
|
key={i}
|
||||||
|
src={src}
|
||||||
|
alt={img.filename}
|
||||||
|
className="rounded max-h-40 cursor-pointer border border-gray-200 dark:border-gray-700"
|
||||||
|
onClick={() => onLightbox(src)}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<p className="text-xs text-gray-400">Files not available (may have been moved or deleted).</p>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Owner-only actions */}
|
||||||
|
{isOwner && (
|
||||||
|
<div className="pt-1 border-t border-gray-100 dark:border-gray-700 space-y-2">
|
||||||
|
{/* Share section */}
|
||||||
|
{shareUrl ? (
|
||||||
|
<div className="space-y-1">
|
||||||
|
<p className="text-xs text-gray-500 dark:text-gray-400">Share link</p>
|
||||||
|
<div className="flex items-center gap-2 flex-wrap">
|
||||||
|
<code className="text-xs bg-gray-100 dark:bg-gray-700 rounded px-2 py-1 text-gray-700 dark:text-gray-300 break-all select-all flex-1 min-w-0">
|
||||||
|
{shareUrl}
|
||||||
|
</code>
|
||||||
|
<button
|
||||||
|
onClick={handleCopy}
|
||||||
|
className="shrink-0 text-xs px-2 py-1 rounded bg-blue-100 dark:bg-blue-900 text-blue-700 dark:text-blue-300 hover:bg-blue-200 dark:hover:bg-blue-800 transition-colors"
|
||||||
|
>
|
||||||
|
{copied ? 'Copied!' : 'Copy'}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={() => revokeMut.mutate()}
|
||||||
|
disabled={revokeMut.isPending}
|
||||||
|
className="shrink-0 text-xs px-2 py-1 rounded bg-red-100 dark:bg-red-900 text-red-700 dark:text-red-300 hover:bg-red-200 dark:hover:bg-red-800 disabled:opacity-50 transition-colors"
|
||||||
|
>
|
||||||
|
{revokeMut.isPending ? 'Revoking…' : 'Revoke'}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<button
|
||||||
|
onClick={() => shareMut.mutate()}
|
||||||
|
disabled={shareMut.isPending}
|
||||||
|
className="text-xs px-2 py-1 rounded bg-gray-100 dark:bg-gray-700 text-gray-600 dark:text-gray-400 hover:bg-gray-200 dark:hover:bg-gray-600 disabled:opacity-50 transition-colors"
|
||||||
|
>
|
||||||
|
{shareMut.isPending ? 'Creating link…' : 'Share'}
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Save-as-preset section */}
|
||||||
|
{!showSavePreset ? (
|
||||||
|
<button
|
||||||
|
onClick={() => { setShowSavePreset(true); setPresetMsg(null); setTimeout(() => presetNameRef.current?.focus(), 50) }}
|
||||||
|
className="text-xs px-2 py-1 rounded bg-gray-100 dark:bg-gray-700 text-gray-600 dark:text-gray-400 hover:bg-gray-200 dark:hover:bg-gray-600 transition-colors"
|
||||||
|
>
|
||||||
|
Save as preset
|
||||||
|
</button>
|
||||||
|
) : (
|
||||||
|
<div className="space-y-1.5">
|
||||||
|
<p className="text-xs text-gray-500 dark:text-gray-400">Save overrides as preset</p>
|
||||||
|
<p className="text-xs text-amber-600 dark:text-amber-400">
|
||||||
|
Note: workflow template is not saved — load it separately before using this preset.
|
||||||
|
</p>
|
||||||
|
<div className="flex gap-2 flex-wrap">
|
||||||
|
<input
|
||||||
|
ref={presetNameRef}
|
||||||
|
type="text"
|
||||||
|
value={presetName}
|
||||||
|
onChange={e => setPresetName(e.target.value)}
|
||||||
|
placeholder="Preset name"
|
||||||
|
className="flex-1 min-w-0 border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-xs bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
|
||||||
|
/>
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={presetDesc}
|
||||||
|
onChange={e => setPresetDesc(e.target.value)}
|
||||||
|
placeholder="Description (optional)"
|
||||||
|
className="flex-1 min-w-0 border border-gray-300 dark:border-gray-600 rounded px-2 py-1 text-xs bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-1 focus:ring-blue-500"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<button
|
||||||
|
onClick={() => savePresetMut.mutate()}
|
||||||
|
disabled={!presetName.trim() || savePresetMut.isPending}
|
||||||
|
className="text-xs px-2 py-1 rounded bg-blue-600 hover:bg-blue-700 disabled:opacity-50 text-white transition-colors"
|
||||||
|
>
|
||||||
|
{savePresetMut.isPending ? 'Saving…' : 'Save'}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={() => { setShowSavePreset(false); setPresetMsg(null); setPresetName(''); setPresetDesc('') }}
|
||||||
|
className="text-xs px-2 py-1 rounded bg-gray-100 dark:bg-gray-700 text-gray-600 dark:text-gray-400 hover:bg-gray-200 dark:hover:bg-gray-600 transition-colors"
|
||||||
|
>
|
||||||
|
Cancel
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
{presetMsg && (
|
||||||
|
<p className={`text-xs ${presetMsg.ok ? 'text-green-600 dark:text-green-400' : 'text-red-500'}`}>
|
||||||
|
{presetMsg.text}
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
216
frontend/src/pages/InputImagesPage.tsx
Normal file
216
frontend/src/pages/InputImagesPage.tsx
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
import React, { useRef, useState } from 'react'
|
||||||
|
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||||
|
import {
|
||||||
|
listInputs,
|
||||||
|
uploadInput,
|
||||||
|
activateInput,
|
||||||
|
deleteInput,
|
||||||
|
getInputImage,
|
||||||
|
getInputThumb,
|
||||||
|
getInputMid,
|
||||||
|
getWorkflowInputs,
|
||||||
|
getState,
|
||||||
|
InputImage,
|
||||||
|
} from '../api/client'
|
||||||
|
import LazyImage from '../components/LazyImage'
|
||||||
|
|
||||||
|
export default function InputImagesPage() {
|
||||||
|
const qc = useQueryClient()
|
||||||
|
const fileRef = useRef<HTMLInputElement>(null)
|
||||||
|
const [uploading, setUploading] = useState(false)
|
||||||
|
const [uploadSlot, setUploadSlot] = useState<string | null>(null)
|
||||||
|
const [lightbox, setLightbox] = useState<string | null>(null)
|
||||||
|
|
||||||
|
const { data: images = [], isLoading } = useQuery({ queryKey: ['inputs'], queryFn: listInputs })
|
||||||
|
const { data: inputsData } = useQuery({ queryKey: ['workflow', 'inputs'], queryFn: getWorkflowInputs })
|
||||||
|
const { data: stateData } = useQuery({ queryKey: ['state'], queryFn: getState })
|
||||||
|
|
||||||
|
const imageSlots = inputsData
|
||||||
|
? [...inputsData.common, ...inputsData.advanced].filter(i => i.input_type === 'image')
|
||||||
|
: []
|
||||||
|
|
||||||
|
const deleteMut = useMutation({
|
||||||
|
mutationFn: (id: number) => deleteInput(id),
|
||||||
|
onSuccess: () => qc.invalidateQueries({ queryKey: ['inputs'] }),
|
||||||
|
})
|
||||||
|
|
||||||
|
const activateMut = useMutation({
|
||||||
|
mutationFn: ({ id, slotKey }: { id: number; slotKey: string }) => activateInput(id, slotKey),
|
||||||
|
onSuccess: () => {
|
||||||
|
qc.invalidateQueries({ queryKey: ['inputs'] })
|
||||||
|
qc.invalidateQueries({ queryKey: ['state'] })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const handleUpload = async (slotKey: string) => {
|
||||||
|
if (!fileRef.current?.files?.length) return
|
||||||
|
setUploading(true)
|
||||||
|
setUploadSlot(slotKey)
|
||||||
|
try {
|
||||||
|
await uploadInput(fileRef.current.files[0], slotKey)
|
||||||
|
qc.invalidateQueries({ queryKey: ['inputs'] })
|
||||||
|
qc.invalidateQueries({ queryKey: ['state'] })
|
||||||
|
if (fileRef.current) fileRef.current.value = ''
|
||||||
|
} finally {
|
||||||
|
setUploading(false)
|
||||||
|
setUploadSlot(null)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const activeForSlot = (slotKey: string): string => {
|
||||||
|
const st = stateData as Record<string, unknown> | undefined
|
||||||
|
return String(st?.[slotKey] ?? '')
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isLoading) return <div className="text-sm text-gray-400">Loading images…</div>
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-6">
|
||||||
|
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Input Images</h1>
|
||||||
|
|
||||||
|
{/* Per-slot sections */}
|
||||||
|
{imageSlots.length > 0 ? (
|
||||||
|
imageSlots.map(slot => {
|
||||||
|
const activeFilename = activeForSlot(slot.key)
|
||||||
|
return (
|
||||||
|
<section key={slot.key} className="space-y-2">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<h2 className="text-sm font-semibold text-gray-700 dark:text-gray-300">
|
||||||
|
{slot.label}
|
||||||
|
{activeFilename && (
|
||||||
|
<span className="ml-2 text-xs font-normal text-blue-500">(active: {activeFilename})</span>
|
||||||
|
)}
|
||||||
|
</h2>
|
||||||
|
<label className="cursor-pointer text-xs bg-blue-600 hover:bg-blue-700 text-white rounded px-2 py-1">
|
||||||
|
{uploading && uploadSlot === slot.key ? 'Uploading…' : 'Upload'}
|
||||||
|
<input
|
||||||
|
ref={fileRef}
|
||||||
|
type="file"
|
||||||
|
accept="image/*"
|
||||||
|
className="hidden"
|
||||||
|
onChange={() => handleUpload(slot.key)}
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<ImageGrid
|
||||||
|
images={images}
|
||||||
|
activeFilename={activeFilename}
|
||||||
|
slotKey={slot.key}
|
||||||
|
onActivate={(id) => activateMut.mutate({ id, slotKey: slot.key })}
|
||||||
|
onDelete={(id) => deleteMut.mutate(id)}
|
||||||
|
onLightbox={setLightbox}
|
||||||
|
/>
|
||||||
|
</section>
|
||||||
|
)
|
||||||
|
})
|
||||||
|
) : (
|
||||||
|
/* No workflow loaded — show flat list */
|
||||||
|
<section className="space-y-2">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<h2 className="text-sm font-semibold text-gray-700 dark:text-gray-300">All images</h2>
|
||||||
|
<label className="cursor-pointer text-xs bg-blue-600 hover:bg-blue-700 text-white rounded px-2 py-1">
|
||||||
|
{uploading ? 'Uploading…' : 'Upload'}
|
||||||
|
<input
|
||||||
|
ref={fileRef}
|
||||||
|
type="file"
|
||||||
|
accept="image/*"
|
||||||
|
className="hidden"
|
||||||
|
onChange={() => handleUpload('input_image')}
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<ImageGrid
|
||||||
|
images={images}
|
||||||
|
activeFilename=""
|
||||||
|
slotKey="input_image"
|
||||||
|
onActivate={(id) => activateMut.mutate({ id, slotKey: 'input_image' })}
|
||||||
|
onDelete={(id) => deleteMut.mutate(id)}
|
||||||
|
onLightbox={setLightbox}
|
||||||
|
/>
|
||||||
|
</section>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{images.length === 0 && (
|
||||||
|
<p className="text-sm text-gray-400">No images yet. Upload one to get started.</p>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Lightbox */}
|
||||||
|
{lightbox && (
|
||||||
|
<div
|
||||||
|
className="fixed inset-0 bg-black/80 flex items-center justify-center z-50"
|
||||||
|
onClick={() => setLightbox(null)}
|
||||||
|
>
|
||||||
|
<button
|
||||||
|
className="absolute top-3 right-3 text-white text-2xl leading-none hover:text-gray-300"
|
||||||
|
onClick={() => setLightbox(null)}
|
||||||
|
aria-label="Close"
|
||||||
|
>
|
||||||
|
✕
|
||||||
|
</button>
|
||||||
|
<img src={lightbox} alt="preview" className="max-w-[90vw] max-h-[90vh] object-contain rounded" />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function ImageGrid({
|
||||||
|
images,
|
||||||
|
activeFilename,
|
||||||
|
slotKey,
|
||||||
|
onActivate,
|
||||||
|
onDelete,
|
||||||
|
onLightbox,
|
||||||
|
}: {
|
||||||
|
images: InputImage[]
|
||||||
|
activeFilename: string
|
||||||
|
slotKey: string
|
||||||
|
onActivate: (id: number) => void
|
||||||
|
onDelete: (id: number) => void
|
||||||
|
onLightbox: (url: string) => void
|
||||||
|
}) {
|
||||||
|
if (images.length === 0) return <p className="text-xs text-gray-400">No images.</p>
|
||||||
|
return (
|
||||||
|
<div className="grid grid-cols-3 sm:grid-cols-5 lg:grid-cols-8 gap-2">
|
||||||
|
{images.map(img => {
|
||||||
|
const isActive = img.filename === activeFilename
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
key={img.id}
|
||||||
|
className={`relative group aspect-square rounded border-2 overflow-hidden cursor-pointer ${
|
||||||
|
isActive ? 'border-blue-500' : 'border-transparent'
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
<LazyImage
|
||||||
|
thumbSrc={getInputThumb(img.id)}
|
||||||
|
midSrc={getInputMid(img.id)}
|
||||||
|
fullSrc={getInputImage(img.id)}
|
||||||
|
alt={img.filename}
|
||||||
|
className="w-full h-full"
|
||||||
|
onClick={() => onLightbox(getInputImage(img.id))}
|
||||||
|
/>
|
||||||
|
{isActive && (
|
||||||
|
<div className="absolute top-0.5 left-0.5 bg-blue-500 text-white text-[9px] px-1 rounded">active</div>
|
||||||
|
)}
|
||||||
|
<div className="absolute inset-0 bg-black/60 opacity-0 group-hover:opacity-100 [@media(hover:none)]:opacity-100 transition-opacity flex flex-col items-center justify-center gap-1">
|
||||||
|
{!isActive && (
|
||||||
|
<button
|
||||||
|
onClick={() => onActivate(img.id)}
|
||||||
|
className="text-[10px] bg-blue-600 text-white rounded px-1.5 py-0.5 hover:bg-blue-700"
|
||||||
|
>
|
||||||
|
Activate
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
<button
|
||||||
|
onClick={() => onDelete(img.id)}
|
||||||
|
className="text-[10px] bg-red-600 text-white rounded px-1.5 py-0.5 hover:bg-red-700"
|
||||||
|
>
|
||||||
|
Delete
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
76
frontend/src/pages/LoginPage.tsx
Normal file
76
frontend/src/pages/LoginPage.tsx
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import React, { useState } from 'react'
|
||||||
|
import { useNavigate } from 'react-router-dom'
|
||||||
|
import { authLogin, adminLogin } from '../api/client'
|
||||||
|
import { useQueryClient } from '@tanstack/react-query'
|
||||||
|
|
||||||
|
export default function LoginPage() {
|
||||||
|
const navigate = useNavigate()
|
||||||
|
const qc = useQueryClient()
|
||||||
|
const [token, setToken] = useState('')
|
||||||
|
const [error, setError] = useState('')
|
||||||
|
const [loading, setLoading] = useState(false)
|
||||||
|
const [isAdmin, setIsAdmin] = useState(false)
|
||||||
|
|
||||||
|
const handleSubmit = async (e: React.FormEvent) => {
|
||||||
|
e.preventDefault()
|
||||||
|
setError('')
|
||||||
|
setLoading(true)
|
||||||
|
try {
|
||||||
|
if (isAdmin) {
|
||||||
|
await adminLogin(token)
|
||||||
|
} else {
|
||||||
|
await authLogin(token)
|
||||||
|
}
|
||||||
|
await qc.invalidateQueries({ queryKey: ['auth', 'me'] })
|
||||||
|
navigate('/generate')
|
||||||
|
} catch (err: unknown) {
|
||||||
|
const msg = err instanceof Error ? err.message : String(err)
|
||||||
|
if (msg.includes('429')) {
|
||||||
|
setError('Too many failed attempts. Please wait 1 hour before trying again.')
|
||||||
|
} else if (msg.includes('401')) {
|
||||||
|
setError('Invalid token or password.')
|
||||||
|
} else {
|
||||||
|
setError(msg)
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="min-h-screen flex items-center justify-center bg-gray-100 dark:bg-gray-900">
|
||||||
|
<div className="bg-white dark:bg-gray-800 shadow rounded-lg p-8 w-full max-w-sm">
|
||||||
|
<h1 className="text-xl font-bold mb-6 text-gray-800 dark:text-gray-100">ComfyUI Bot</h1>
|
||||||
|
<form onSubmit={handleSubmit} className="space-y-4">
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300 mb-1">
|
||||||
|
{isAdmin ? 'Admin password' : 'Invite token'}
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
type="password"
|
||||||
|
value={token}
|
||||||
|
onChange={e => setToken(e.target.value)}
|
||||||
|
className="w-full border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||||
|
placeholder={isAdmin ? 'Password' : 'Paste your invite token'}
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{error && <p className="text-red-500 text-sm">{error}</p>}
|
||||||
|
<button
|
||||||
|
type="submit"
|
||||||
|
disabled={loading}
|
||||||
|
className="w-full bg-blue-600 hover:bg-blue-700 disabled:opacity-50 text-white rounded px-4 py-2 text-sm font-medium transition-colors"
|
||||||
|
>
|
||||||
|
{loading ? 'Logging in…' : 'Log in'}
|
||||||
|
</button>
|
||||||
|
</form>
|
||||||
|
<button
|
||||||
|
onClick={() => { setIsAdmin(a => !a); setError('') }}
|
||||||
|
className="mt-4 text-xs text-gray-400 hover:text-gray-600 dark:hover:text-gray-200 w-full text-center"
|
||||||
|
>
|
||||||
|
{isAdmin ? 'Use invite token instead' : 'Admin login'}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
196
frontend/src/pages/PresetsPage.tsx
Normal file
196
frontend/src/pages/PresetsPage.tsx
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
import React, { useState } from 'react'
|
||||||
|
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||||
|
import { listPresets, savePreset, getPreset, loadPreset, deletePreset, PresetMeta } from '../api/client'
|
||||||
|
|
||||||
|
export default function PresetsPage() {
|
||||||
|
const qc = useQueryClient()
|
||||||
|
const [newName, setNewName] = useState('')
|
||||||
|
const [newDescription, setNewDescription] = useState('')
|
||||||
|
const [savingError, setSavingError] = useState('')
|
||||||
|
const [expanded, setExpanded] = useState<string | null>(null)
|
||||||
|
const [message, setMessage] = useState<string | null>(null)
|
||||||
|
|
||||||
|
const { data, isLoading } = useQuery({ queryKey: ['presets'], queryFn: listPresets })
|
||||||
|
|
||||||
|
const { data: presetDetail } = useQuery({
|
||||||
|
queryKey: ['preset', expanded],
|
||||||
|
queryFn: () => getPreset(expanded!),
|
||||||
|
enabled: !!expanded,
|
||||||
|
})
|
||||||
|
|
||||||
|
const saveMut = useMutation({
|
||||||
|
mutationFn: ({ name, description }: { name: string; description: string }) =>
|
||||||
|
savePreset(name, description || undefined),
|
||||||
|
onSuccess: () => {
|
||||||
|
qc.invalidateQueries({ queryKey: ['presets'] })
|
||||||
|
setNewName('')
|
||||||
|
setNewDescription('')
|
||||||
|
setSavingError('')
|
||||||
|
setMessage('Preset saved.')
|
||||||
|
},
|
||||||
|
onError: (err) => setSavingError(err instanceof Error ? err.message : String(err)),
|
||||||
|
})
|
||||||
|
|
||||||
|
const loadMut = useMutation({
|
||||||
|
mutationFn: (name: string) => loadPreset(name),
|
||||||
|
onSuccess: (_, name) => {
|
||||||
|
qc.invalidateQueries({ queryKey: ['state'] })
|
||||||
|
qc.invalidateQueries({ queryKey: ['workflowInputs'] })
|
||||||
|
setMessage(`Loaded preset: ${name}`)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const deleteMut = useMutation({
|
||||||
|
mutationFn: (name: string) => deletePreset(name),
|
||||||
|
onSuccess: () => {
|
||||||
|
qc.invalidateQueries({ queryKey: ['presets'] })
|
||||||
|
setMessage('Preset deleted.')
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="max-w-xl mx-auto space-y-6">
|
||||||
|
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Presets</h1>
|
||||||
|
|
||||||
|
{message && (
|
||||||
|
<div className="text-sm text-blue-600 dark:text-blue-400">{message}</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Save current state */}
|
||||||
|
<div className="space-y-2">
|
||||||
|
<div className="flex flex-col sm:flex-row gap-2">
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={newName}
|
||||||
|
onChange={e => setNewName(e.target.value)}
|
||||||
|
placeholder="Preset name"
|
||||||
|
className="flex-1 border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
onClick={() => { setSavingError(''); saveMut.mutate({ name: newName, description: newDescription }) }}
|
||||||
|
disabled={!newName.trim() || saveMut.isPending}
|
||||||
|
className="bg-blue-600 hover:bg-blue-700 disabled:opacity-50 text-white rounded px-3 py-2 text-sm font-medium"
|
||||||
|
>
|
||||||
|
Save current state
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={newDescription}
|
||||||
|
onChange={e => setNewDescription(e.target.value)}
|
||||||
|
placeholder="Description (optional)"
|
||||||
|
className="w-full border border-gray-300 dark:border-gray-600 rounded px-3 py-2 text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{savingError && <p className="text-red-500 text-sm">{savingError}</p>}
|
||||||
|
|
||||||
|
{/* Preset list */}
|
||||||
|
{isLoading ? (
|
||||||
|
<div className="text-sm text-gray-400">Loading presets…</div>
|
||||||
|
) : (data?.presets ?? []).length === 0 ? (
|
||||||
|
<p className="text-sm text-gray-400">No presets saved yet.</p>
|
||||||
|
) : (
|
||||||
|
<ul className="divide-y divide-gray-200 dark:divide-gray-700 border border-gray-200 dark:border-gray-700 rounded">
|
||||||
|
{(data?.presets ?? []).map((preset: PresetMeta) => (
|
||||||
|
<li key={preset.name} className="px-3 py-2 space-y-1">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<button
|
||||||
|
onClick={() => setExpanded(expanded === preset.name ? null : preset.name)}
|
||||||
|
className="text-sm font-medium text-gray-700 dark:text-gray-300 hover:text-blue-600 dark:hover:text-blue-400 text-left"
|
||||||
|
>
|
||||||
|
{preset.name}
|
||||||
|
{preset.owner && (
|
||||||
|
<span className="ml-2 text-xs text-gray-400 dark:text-gray-500 font-normal">
|
||||||
|
{preset.owner}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</button>
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<button
|
||||||
|
onClick={() => { setMessage(null); loadMut.mutate(preset.name) }}
|
||||||
|
disabled={loadMut.isPending}
|
||||||
|
className="text-xs bg-green-600 hover:bg-green-700 disabled:opacity-50 text-white rounded px-2 py-1"
|
||||||
|
>
|
||||||
|
Load
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={() => { setMessage(null); deleteMut.mutate(preset.name) }}
|
||||||
|
className="text-xs bg-red-600 hover:bg-red-700 text-white rounded px-2 py-1"
|
||||||
|
>
|
||||||
|
Delete
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{preset.description && (
|
||||||
|
<p className="text-xs italic text-gray-400 dark:text-gray-500">{preset.description}</p>
|
||||||
|
)}
|
||||||
|
{expanded === preset.name && presetDetail && (
|
||||||
|
<PresetDetail data={presetDetail} />
|
||||||
|
)}
|
||||||
|
</li>
|
||||||
|
))}
|
||||||
|
</ul>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function PresetDetail({ data }: { data: Record<string, unknown> }) {
|
||||||
|
const state = (data.state ?? {}) as Record<string, unknown>
|
||||||
|
const { prompt, negative_prompt, seed, ...otherOverrides } = state
|
||||||
|
const hasOther = Object.keys(otherOverrides).length > 0
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="mt-2 space-y-2 text-xs">
|
||||||
|
{prompt != null && (
|
||||||
|
<div>
|
||||||
|
<span className="font-semibold text-gray-600 dark:text-gray-400">Prompt</span>
|
||||||
|
<p className="mt-0.5 bg-gray-50 dark:bg-gray-800 rounded p-2 text-gray-700 dark:text-gray-300 whitespace-pre-wrap break-words">
|
||||||
|
{String(prompt)}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{negative_prompt != null && (
|
||||||
|
<div>
|
||||||
|
<span className="font-semibold text-gray-600 dark:text-gray-400">Negative prompt</span>
|
||||||
|
<p className="mt-0.5 bg-gray-50 dark:bg-gray-800 rounded p-2 text-gray-700 dark:text-gray-300 whitespace-pre-wrap break-words">
|
||||||
|
{String(negative_prompt)}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{seed != null && (
|
||||||
|
<div className="flex gap-2 items-baseline">
|
||||||
|
<span className="font-semibold text-gray-600 dark:text-gray-400">Seed</span>
|
||||||
|
<span className="font-mono text-gray-700 dark:text-gray-300">
|
||||||
|
{String(seed)}
|
||||||
|
{seed === -1 && <span className="ml-1 text-gray-400">(random)</span>}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{hasOther && (
|
||||||
|
<div>
|
||||||
|
<span className="font-semibold text-gray-600 dark:text-gray-400">Other overrides</span>
|
||||||
|
<table className="mt-0.5 w-full text-xs border-collapse">
|
||||||
|
<tbody>
|
||||||
|
{Object.entries(otherOverrides).map(([k, v]) => (
|
||||||
|
<tr key={k} className="border-b border-gray-100 dark:border-gray-700">
|
||||||
|
<td className="py-0.5 pr-3 font-mono text-gray-500 dark:text-gray-400 whitespace-nowrap">{k}</td>
|
||||||
|
<td className="py-0.5 text-gray-700 dark:text-gray-300 break-all">{JSON.stringify(v)}</td>
|
||||||
|
</tr>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{!!data.workflow && (
|
||||||
|
<div className="text-green-600 dark:text-green-400">Includes workflow template</div>
|
||||||
|
)}
|
||||||
|
<details>
|
||||||
|
<summary className="cursor-pointer text-gray-400 hover:text-gray-600 dark:hover:text-gray-300">Raw JSON</summary>
|
||||||
|
<pre className="mt-1 text-xs bg-gray-100 dark:bg-gray-800 rounded p-2 overflow-auto max-h-48 text-gray-600 dark:text-gray-400">
|
||||||
|
{JSON.stringify(data, null, 2)}
|
||||||
|
</pre>
|
||||||
|
</details>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
100
frontend/src/pages/ServerPage.tsx
Normal file
100
frontend/src/pages/ServerPage.tsx
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import React, { useEffect, useRef, useState } from 'react'
|
||||||
|
import { useQuery, useMutation } from '@tanstack/react-query'
|
||||||
|
import { getServerStatus, serverAction, tailLogs } from '../api/client'
|
||||||
|
|
||||||
|
const ACTIONS = ['start', 'stop', 'restart'] as const
|
||||||
|
|
||||||
|
export default function ServerPage() {
|
||||||
|
const [actionMsg, setActionMsg] = useState<string | null>(null)
|
||||||
|
const logRef = useRef<HTMLPreElement>(null)
|
||||||
|
|
||||||
|
const { data: srv, refetch: refetchStatus } = useQuery({
|
||||||
|
queryKey: ['server', 'status'],
|
||||||
|
queryFn: getServerStatus,
|
||||||
|
refetchInterval: 5000,
|
||||||
|
})
|
||||||
|
|
||||||
|
const { data: logsData, refetch: refetchLogs } = useQuery({
|
||||||
|
queryKey: ['logs'],
|
||||||
|
queryFn: () => tailLogs(200),
|
||||||
|
refetchInterval: 2000,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Auto-scroll log to bottom
|
||||||
|
useEffect(() => {
|
||||||
|
if (logRef.current) {
|
||||||
|
logRef.current.scrollTop = logRef.current.scrollHeight
|
||||||
|
}
|
||||||
|
}, [logsData])
|
||||||
|
|
||||||
|
const actionMut = useMutation({
|
||||||
|
mutationFn: (action: string) => serverAction(action),
|
||||||
|
onSuccess: (_, action) => {
|
||||||
|
setActionMsg(`${action} sent.`)
|
||||||
|
setTimeout(() => { setActionMsg(null); refetchStatus() }, 2000)
|
||||||
|
},
|
||||||
|
onError: (err) => setActionMsg(`Error: ${err instanceof Error ? err.message : String(err)}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
const stateColor = srv?.service_state === 'SERVICE_RUNNING'
|
||||||
|
? 'text-green-600 dark:text-green-400'
|
||||||
|
: srv?.service_state === 'SERVICE_STOPPED'
|
||||||
|
? 'text-red-500 dark:text-red-400'
|
||||||
|
: 'text-yellow-500'
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="max-w-2xl mx-auto space-y-6">
|
||||||
|
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Server</h1>
|
||||||
|
|
||||||
|
{/* Status */}
|
||||||
|
<div className="bg-white dark:bg-gray-800 rounded border border-gray-200 dark:border-gray-700 p-4 text-sm space-y-1">
|
||||||
|
<p>
|
||||||
|
State: <span className={`font-medium ${stateColor}`}>{srv?.service_state ?? '—'}</span>
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
HTTP: <span className={srv?.http_reachable ? 'text-green-600 dark:text-green-400' : 'text-red-500'}>
|
||||||
|
{srv == null ? '—' : srv.http_reachable ? '✅ reachable' : '❌ unreachable'}
|
||||||
|
</span>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Controls */}
|
||||||
|
<div className="flex gap-2">
|
||||||
|
{ACTIONS.map(a => (
|
||||||
|
<button
|
||||||
|
key={a}
|
||||||
|
onClick={() => { setActionMsg(null); actionMut.mutate(a) }}
|
||||||
|
disabled={actionMut.isPending}
|
||||||
|
className={`text-sm rounded px-3 py-2 font-medium disabled:opacity-50 transition-colors ${
|
||||||
|
a === 'stop' ? 'bg-red-600 hover:bg-red-700 text-white'
|
||||||
|
: a === 'restart' ? 'bg-yellow-500 hover:bg-yellow-600 text-white'
|
||||||
|
: 'bg-green-600 hover:bg-green-700 text-white'
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
{a.charAt(0).toUpperCase() + a.slice(1)}
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
{actionMsg && <p className="text-sm text-blue-600 dark:text-blue-400">{actionMsg}</p>}
|
||||||
|
|
||||||
|
{/* Log tail */}
|
||||||
|
<div className="space-y-1">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<p className="text-sm font-medium text-gray-700 dark:text-gray-300">Log tail (last 200 lines)</p>
|
||||||
|
<button
|
||||||
|
onClick={() => refetchLogs()}
|
||||||
|
className="text-xs text-gray-400 hover:text-gray-600 dark:hover:text-gray-200"
|
||||||
|
>
|
||||||
|
Refresh
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<pre
|
||||||
|
ref={logRef}
|
||||||
|
className="bg-gray-900 text-gray-100 text-xs rounded p-3 h-72 overflow-y-auto whitespace-pre-wrap font-mono"
|
||||||
|
>
|
||||||
|
{(logsData?.lines ?? []).join('\n') || 'No log lines available.'}
|
||||||
|
</pre>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
119
frontend/src/pages/SharePage.tsx
Normal file
119
frontend/src/pages/SharePage.tsx
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import { useParams, Link } from 'react-router-dom'
|
||||||
|
import { useQuery } from '@tanstack/react-query'
|
||||||
|
import { getShareFileUrl } from '../api/client'
|
||||||
|
|
||||||
|
interface ShareData {
|
||||||
|
prompt_id: string
|
||||||
|
created_at: string
|
||||||
|
overrides: Record<string, unknown>
|
||||||
|
seed?: number
|
||||||
|
images: Array<{ filename: string; data: string | null; mime_type: string }>
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function SharePage() {
|
||||||
|
const { token } = useParams<{ token: string }>()
|
||||||
|
|
||||||
|
const { data, isLoading, error } = useQuery({
|
||||||
|
queryKey: ['share', token],
|
||||||
|
queryFn: async () => {
|
||||||
|
const res = await fetch(`/api/share/${token}`, { credentials: 'include' })
|
||||||
|
if (!res.ok) {
|
||||||
|
const msg = await res.text().catch(() => res.statusText)
|
||||||
|
throw Object.assign(new Error(msg), { status: res.status })
|
||||||
|
}
|
||||||
|
return res.json() as Promise<ShareData>
|
||||||
|
},
|
||||||
|
retry: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
if (isLoading) {
|
||||||
|
return (
|
||||||
|
<div className="min-h-screen flex items-center justify-center bg-gray-100 dark:bg-gray-900">
|
||||||
|
<p className="text-sm text-gray-400">Loading…</p>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const status = (error as any)?.status
|
||||||
|
|
||||||
|
if (status === 401) {
|
||||||
|
return (
|
||||||
|
<div className="min-h-screen flex items-center justify-center bg-gray-100 dark:bg-gray-900">
|
||||||
|
<div className="bg-white dark:bg-gray-800 shadow rounded-lg p-8 w-full max-w-sm text-center space-y-4">
|
||||||
|
<p className="text-gray-700 dark:text-gray-300">You need to be logged in to view this shared link.</p>
|
||||||
|
<Link
|
||||||
|
to="/login"
|
||||||
|
className="inline-block bg-blue-600 hover:bg-blue-700 text-white rounded px-4 py-2 text-sm font-medium transition-colors"
|
||||||
|
>
|
||||||
|
Log in
|
||||||
|
</Link>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
return (
|
||||||
|
<div className="min-h-screen flex items-center justify-center bg-gray-100 dark:bg-gray-900">
|
||||||
|
<div className="bg-white dark:bg-gray-800 shadow rounded-lg p-8 w-full max-w-sm text-center">
|
||||||
|
<p className="text-gray-700 dark:text-gray-300">This share link has been revoked or does not exist.</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="min-h-screen bg-gray-100 dark:bg-gray-900 py-8 px-4">
|
||||||
|
<div className="max-w-3xl mx-auto space-y-6">
|
||||||
|
<div className="bg-white dark:bg-gray-800 shadow rounded-lg p-6 space-y-4">
|
||||||
|
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Shared Generation</h1>
|
||||||
|
|
||||||
|
<div className="text-sm text-gray-500 dark:text-gray-400 space-y-1">
|
||||||
|
<p>Generated: {data && new Date(data.created_at).toLocaleString()}</p>
|
||||||
|
{data?.seed != null && (
|
||||||
|
<p>Seed: <span className="font-mono">{data.seed}</span></p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{data && Object.keys(data.overrides).length > 0 && (
|
||||||
|
<details>
|
||||||
|
<summary className="text-xs text-gray-400 cursor-pointer">Overrides</summary>
|
||||||
|
<pre className="text-xs bg-gray-100 dark:bg-gray-700 rounded p-2 mt-1 overflow-auto max-h-40 text-gray-600 dark:text-gray-400">
|
||||||
|
{JSON.stringify(data.overrides, null, 2)}
|
||||||
|
</pre>
|
||||||
|
</details>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Images / Videos */}
|
||||||
|
<div className="flex gap-3 flex-wrap">
|
||||||
|
{(data?.images ?? []).map((img, i) => {
|
||||||
|
if (img.mime_type.startsWith('video/')) {
|
||||||
|
return (
|
||||||
|
<video
|
||||||
|
key={i}
|
||||||
|
src={getShareFileUrl(token!, img.filename)}
|
||||||
|
controls
|
||||||
|
className="rounded max-h-80 max-w-full"
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<img
|
||||||
|
key={i}
|
||||||
|
src={`data:${img.mime_type};base64,${img.data}`}
|
||||||
|
alt={img.filename}
|
||||||
|
className="rounded max-h-80 max-w-full object-contain border border-gray-200 dark:border-gray-700"
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p className="text-center text-xs text-gray-400">
|
||||||
|
<Link to="/login" className="hover:underline">ComfyUI Bot</Link>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
77
frontend/src/pages/StatusPage.tsx
Normal file
77
frontend/src/pages/StatusPage.tsx
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import { useAuth } from '../hooks/useAuth'
|
||||||
|
import { useStatus } from '../hooks/useStatus'
|
||||||
|
|
||||||
|
export default function StatusPage() {
|
||||||
|
const { user } = useAuth()
|
||||||
|
const { status, executingNode } = useStatus({ enabled: !!user })
|
||||||
|
|
||||||
|
const { bot, comfy, service, upload } = status
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="max-w-2xl mx-auto space-y-6">
|
||||||
|
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Status</h1>
|
||||||
|
|
||||||
|
<div className="grid grid-cols-1 sm:grid-cols-2 gap-4">
|
||||||
|
|
||||||
|
{/* Bot */}
|
||||||
|
<Card title="Bot">
|
||||||
|
<Row label="Latency" value={bot ? `${bot.latency_ms} ms` : '—'} />
|
||||||
|
<Row label="Uptime" value={bot?.uptime ?? '—'} />
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
{/* ComfyUI */}
|
||||||
|
<Card title="ComfyUI">
|
||||||
|
<Row label="Server" value={comfy?.server ?? '—'} />
|
||||||
|
<Row
|
||||||
|
label="Reachable"
|
||||||
|
value={comfy?.reachable == null ? '—' : comfy.reachable ? '✅ yes' : '❌ no'}
|
||||||
|
/>
|
||||||
|
<Row label="Queue running" value={String(comfy?.queue_running ?? 0)} />
|
||||||
|
<Row label="Queue pending" value={String(comfy?.queue_pending ?? 0)} />
|
||||||
|
<Row label="Workflow loaded" value={comfy?.workflow_loaded ? '✓' : '✗'} />
|
||||||
|
<Row label="Last seed" value={comfy?.last_seed != null ? String(comfy.last_seed) : '—'} />
|
||||||
|
<Row label="Total generated" value={String(comfy?.total_generated ?? 0)} />
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
{/* Service */}
|
||||||
|
<Card title="Service">
|
||||||
|
<Row label="State" value={service?.state ?? '—'} />
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
{/* Auto-upload */}
|
||||||
|
<Card title="Auto-upload">
|
||||||
|
<Row label="Configured" value={upload?.configured ? '✓' : '✗'} />
|
||||||
|
<Row label="Running" value={upload?.running ? '⏳ yes' : 'idle'} />
|
||||||
|
<Row label="Total ok" value={String(upload?.total_ok ?? 0)} />
|
||||||
|
<Row label="Total fail" value={String(upload?.total_fail ?? 0)} />
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Executing node */}
|
||||||
|
{executingNode && (
|
||||||
|
<div className="bg-blue-50 dark:bg-blue-900/30 border border-blue-200 dark:border-blue-700 rounded p-3 text-sm text-blue-700 dark:text-blue-300">
|
||||||
|
Executing node: <strong>{executingNode}</strong>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function Card({ title, children }: { title: string; children: React.ReactNode }) {
|
||||||
|
return (
|
||||||
|
<div className="bg-white dark:bg-gray-800 rounded border border-gray-200 dark:border-gray-700 p-4">
|
||||||
|
<p className="text-xs font-semibold uppercase tracking-wide text-gray-400 dark:text-gray-500 mb-2">{title}</p>
|
||||||
|
<dl className="space-y-1">{children}</dl>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function Row({ label, value }: { label: string; value: string }) {
|
||||||
|
return (
|
||||||
|
<div className="flex justify-between text-sm">
|
||||||
|
<dt className="text-gray-500 dark:text-gray-400">{label}</dt>
|
||||||
|
<dd className="text-gray-800 dark:text-gray-200 font-mono text-right">{value}</dd>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
151
frontend/src/pages/WorkflowPage.tsx
Normal file
151
frontend/src/pages/WorkflowPage.tsx
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import React, { useRef, useState } from 'react'
|
||||||
|
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||||
|
import {
|
||||||
|
getWorkflow,
|
||||||
|
getWorkflowInputs,
|
||||||
|
listWorkflowFiles,
|
||||||
|
uploadWorkflow,
|
||||||
|
loadWorkflow,
|
||||||
|
} from '../api/client'
|
||||||
|
|
||||||
|
export default function WorkflowPage() {
|
||||||
|
const qc = useQueryClient()
|
||||||
|
const fileRef = useRef<HTMLInputElement>(null)
|
||||||
|
const [uploading, setUploading] = useState(false)
|
||||||
|
const [loadingFile, setLoadingFile] = useState<string | null>(null)
|
||||||
|
const [message, setMessage] = useState<string | null>(null)
|
||||||
|
|
||||||
|
const { data: wf } = useQuery({ queryKey: ['workflow'], queryFn: getWorkflow })
|
||||||
|
const { data: inputs } = useQuery({
|
||||||
|
queryKey: ['workflow', 'inputs'],
|
||||||
|
queryFn: getWorkflowInputs,
|
||||||
|
enabled: wf?.loaded ?? false,
|
||||||
|
})
|
||||||
|
const { data: filesData, refetch: refetchFiles } = useQuery({
|
||||||
|
queryKey: ['workflow', 'files'],
|
||||||
|
queryFn: listWorkflowFiles,
|
||||||
|
})
|
||||||
|
|
||||||
|
const loadMut = useMutation({
|
||||||
|
mutationFn: (filename: string) => loadWorkflow(filename),
|
||||||
|
onSuccess: (_, filename) => {
|
||||||
|
setMessage(`Loaded: ${filename}`)
|
||||||
|
qc.invalidateQueries({ queryKey: ['workflow'] })
|
||||||
|
qc.invalidateQueries({ queryKey: ['state'] })
|
||||||
|
setLoadingFile(null)
|
||||||
|
},
|
||||||
|
onError: (err) => {
|
||||||
|
setMessage(`Error: ${err instanceof Error ? err.message : String(err)}`)
|
||||||
|
setLoadingFile(null)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const handleUpload = async () => {
|
||||||
|
if (!fileRef.current?.files?.length) return
|
||||||
|
setUploading(true)
|
||||||
|
setMessage(null)
|
||||||
|
try {
|
||||||
|
const res = await uploadWorkflow(fileRef.current.files[0])
|
||||||
|
setMessage(`Uploaded: ${res.filename ?? 'ok'}`)
|
||||||
|
refetchFiles()
|
||||||
|
if (fileRef.current) fileRef.current.value = ''
|
||||||
|
} catch (err) {
|
||||||
|
setMessage(`Upload error: ${err instanceof Error ? err.message : String(err)}`)
|
||||||
|
} finally {
|
||||||
|
setUploading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const allInputs = [...(inputs?.common ?? []), ...(inputs?.advanced ?? [])]
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="max-w-2xl mx-auto space-y-6">
|
||||||
|
<h1 className="text-xl font-bold text-gray-800 dark:text-gray-100">Workflow</h1>
|
||||||
|
|
||||||
|
{/* Current workflow */}
|
||||||
|
<div className="bg-white dark:bg-gray-800 rounded border border-gray-200 dark:border-gray-700 p-4 text-sm space-y-1">
|
||||||
|
<p className="font-medium text-gray-700 dark:text-gray-300">Current workflow</p>
|
||||||
|
{wf?.loaded ? (
|
||||||
|
<>
|
||||||
|
<p className="text-gray-500 dark:text-gray-400">{wf.last_workflow_file ?? '(loaded from state)'}</p>
|
||||||
|
<p className="text-gray-500 dark:text-gray-400">{wf.node_count} node(s) detected</p>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<p className="text-gray-400">No workflow loaded</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{message && (
|
||||||
|
<div className="text-sm text-blue-600 dark:text-blue-400">{message}</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Upload */}
|
||||||
|
<div className="flex items-center gap-3">
|
||||||
|
<label className="cursor-pointer text-sm bg-blue-600 hover:bg-blue-700 text-white rounded px-3 py-2">
|
||||||
|
{uploading ? 'Uploading…' : 'Upload workflow JSON'}
|
||||||
|
<input
|
||||||
|
ref={fileRef}
|
||||||
|
type="file"
|
||||||
|
accept=".json"
|
||||||
|
className="hidden"
|
||||||
|
onChange={handleUpload}
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
<span className="text-xs text-gray-400">Uploads to workflows/ folder</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Available files */}
|
||||||
|
{(filesData?.files ?? []).length > 0 && (
|
||||||
|
<div className="space-y-2">
|
||||||
|
<p className="text-sm font-medium text-gray-700 dark:text-gray-300">Available workflows</p>
|
||||||
|
<ul className="divide-y divide-gray-200 dark:divide-gray-700 border border-gray-200 dark:border-gray-700 rounded">
|
||||||
|
{(filesData?.files ?? []).map(f => (
|
||||||
|
<li key={f} className="flex items-center justify-between px-3 py-2 text-sm">
|
||||||
|
<span className={`text-gray-700 dark:text-gray-300 ${wf?.last_workflow_file === f ? 'font-semibold text-blue-600 dark:text-blue-400' : ''}`}>
|
||||||
|
{f}
|
||||||
|
{wf?.last_workflow_file === f && <span className="ml-1 text-xs">(active)</span>}
|
||||||
|
</span>
|
||||||
|
<button
|
||||||
|
onClick={() => { setLoadingFile(f); setMessage(null); loadMut.mutate(f) }}
|
||||||
|
disabled={loadingFile === f}
|
||||||
|
className="text-xs bg-gray-200 dark:bg-gray-700 hover:bg-gray-300 dark:hover:bg-gray-600 rounded px-2 py-1 disabled:opacity-50"
|
||||||
|
>
|
||||||
|
{loadingFile === f ? 'Loading…' : 'Load'}
|
||||||
|
</button>
|
||||||
|
</li>
|
||||||
|
))}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Discovered inputs summary */}
|
||||||
|
{inputs && allInputs.length > 0 && (
|
||||||
|
<div className="space-y-2">
|
||||||
|
<p className="text-sm font-medium text-gray-700 dark:text-gray-300">Discovered inputs ({allInputs.length})</p>
|
||||||
|
<div className="overflow-x-auto">
|
||||||
|
<table className="w-full text-xs border-collapse border border-gray-200 dark:border-gray-700 rounded">
|
||||||
|
<thead>
|
||||||
|
<tr className="bg-gray-100 dark:bg-gray-700">
|
||||||
|
<th className="text-left p-2 border-b border-gray-200 dark:border-gray-600">Key</th>
|
||||||
|
<th className="text-left p-2 border-b border-gray-200 dark:border-gray-600">Label</th>
|
||||||
|
<th className="text-left p-2 border-b border-gray-200 dark:border-gray-600">Type</th>
|
||||||
|
<th className="text-left p-2 border-b border-gray-200 dark:border-gray-600">Common</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{allInputs.map(inp => (
|
||||||
|
<tr key={inp.key} className="border-b border-gray-100 dark:border-gray-700 last:border-0">
|
||||||
|
<td className="p-2 font-mono text-gray-600 dark:text-gray-400">{inp.key}</td>
|
||||||
|
<td className="p-2 text-gray-700 dark:text-gray-300">{inp.label}</td>
|
||||||
|
<td className="p-2 text-gray-500 dark:text-gray-400">{inp.input_type}</td>
|
||||||
|
<td className="p-2 text-gray-500 dark:text-gray-400">{inp.is_common ? '✓' : ''}</td>
|
||||||
|
</tr>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
9
frontend/tailwind.config.js
Normal file
9
frontend/tailwind.config.js
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
/** @type {import('tailwindcss').Config} */
|
||||||
|
export default {
|
||||||
|
content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
|
||||||
|
darkMode: 'class',
|
||||||
|
theme: {
|
||||||
|
extend: {},
|
||||||
|
},
|
||||||
|
plugins: [],
|
||||||
|
}
|
||||||
24
frontend/tsconfig.json
Normal file
24
frontend/tsconfig.json
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"target": "ES2020",
|
||||||
|
"useDefineForClassFields": true,
|
||||||
|
"lib": ["ES2020", "DOM", "DOM.Iterable"],
|
||||||
|
"module": "ESNext",
|
||||||
|
"skipLibCheck": true,
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"allowImportingTsExtensions": true,
|
||||||
|
"resolveJsonModule": true,
|
||||||
|
"isolatedModules": true,
|
||||||
|
"noEmit": true,
|
||||||
|
"jsx": "react-jsx",
|
||||||
|
"strict": true,
|
||||||
|
"noUnusedLocals": false,
|
||||||
|
"noUnusedParameters": false,
|
||||||
|
"noFallthroughCasesInSwitch": true,
|
||||||
|
"baseUrl": ".",
|
||||||
|
"paths": {
|
||||||
|
"@/*": ["src/*"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"include": ["src"]
|
||||||
|
}
|
||||||
25
frontend/vite.config.ts
Normal file
25
frontend/vite.config.ts
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import { defineConfig } from 'vite'
|
||||||
|
import react from '@vitejs/plugin-react'
|
||||||
|
|
||||||
|
export default defineConfig({
|
||||||
|
plugins: [react()],
|
||||||
|
build: {
|
||||||
|
outDir: '../web-static',
|
||||||
|
emptyOutDir: true,
|
||||||
|
},
|
||||||
|
server: {
|
||||||
|
port: 5173,
|
||||||
|
proxy: {
|
||||||
|
'/api': {
|
||||||
|
target: 'http://localhost:8080',
|
||||||
|
changeOrigin: true,
|
||||||
|
secure: false,
|
||||||
|
},
|
||||||
|
'/ws': {
|
||||||
|
target: 'ws://localhost:8080',
|
||||||
|
ws: true,
|
||||||
|
changeOrigin: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
322
generation_db.py
Normal file
322
generation_db.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
"""
|
||||||
|
generation_db.py
|
||||||
|
================
|
||||||
|
|
||||||
|
SQLite persistence for ComfyUI generation history and output file blobs.
|
||||||
|
|
||||||
|
Two tables
|
||||||
|
----------
|
||||||
|
generation_history : one row per prompt submitted to ComfyUI
|
||||||
|
generation_files : one row per output file (image / video) as a BLOB
|
||||||
|
|
||||||
|
The module-level ``_DB_PATH`` is set by :func:`init_db`; all other
|
||||||
|
functions use that path so callers never need to pass it around.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import secrets
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
_DB_PATH: Path = Path(__file__).parent / "generation_history.db"
|
||||||
|
|
||||||
|
_SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS generation_history (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
prompt_id TEXT UNIQUE NOT NULL,
|
||||||
|
source TEXT NOT NULL,
|
||||||
|
user_label TEXT,
|
||||||
|
overrides TEXT,
|
||||||
|
seed INTEGER,
|
||||||
|
created_at TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS generation_files (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
generation_id INTEGER NOT NULL REFERENCES generation_history(id),
|
||||||
|
filename TEXT NOT NULL,
|
||||||
|
file_data BLOB NOT NULL,
|
||||||
|
mime_type TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS generation_shares (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
share_token TEXT UNIQUE NOT NULL,
|
||||||
|
prompt_id TEXT NOT NULL,
|
||||||
|
owner_label TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _connect(db_path: Path | None = None) -> sqlite3.Connection:
|
||||||
|
path = db_path if db_path is not None else _DB_PATH
|
||||||
|
conn = sqlite3.connect(str(path), check_same_thread=False)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_mime(data: bytes) -> str:
|
||||||
|
"""Detect MIME type from magic bytes."""
|
||||||
|
if data[:8] == b"\x89PNG\r\n\x1a\n":
|
||||||
|
return "image/png"
|
||||||
|
if data[:2] == b"\xff\xd8":
|
||||||
|
return "image/jpeg"
|
||||||
|
if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"WEBP":
|
||||||
|
return "image/webp"
|
||||||
|
if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"AVI ":
|
||||||
|
return "video/x-msvideo"
|
||||||
|
if len(data) >= 8 and data[4:8] == b"ftyp":
|
||||||
|
return "video/mp4"
|
||||||
|
if data[:4] == b"\x1aE\xdf\xa3": # EBML (WebM/MKV)
|
||||||
|
return "video/webm"
|
||||||
|
return "application/octet-stream"
|
||||||
|
|
||||||
|
|
||||||
|
def init_db(db_path: Path = _DB_PATH) -> None:
|
||||||
|
"""Create tables if they don't exist. Accepts a path for testability."""
|
||||||
|
global _DB_PATH
|
||||||
|
_DB_PATH = db_path
|
||||||
|
with _connect(db_path) as conn:
|
||||||
|
conn.executescript(_SCHEMA)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def record_generation(
|
||||||
|
prompt_id: str,
|
||||||
|
source: str,
|
||||||
|
user_label: str | None,
|
||||||
|
overrides_dict: dict[str, Any] | None,
|
||||||
|
seed: int | None,
|
||||||
|
) -> int:
|
||||||
|
"""Insert a generation history row. Returns the auto-increment ``id``."""
|
||||||
|
overrides_json = json.dumps(overrides_dict) if overrides_dict is not None else None
|
||||||
|
created_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
with _connect() as conn:
|
||||||
|
cur = conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO generation_history
|
||||||
|
(prompt_id, source, user_label, overrides, seed, created_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(prompt_id, source, user_label, overrides_json, seed, created_at),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return cur.lastrowid # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
|
def record_file(generation_id: int, filename: str, file_data: bytes) -> None:
|
||||||
|
"""Insert a file BLOB row, auto-detecting MIME type from magic bytes."""
|
||||||
|
mime_type = _detect_mime(file_data)
|
||||||
|
with _connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO generation_files (generation_id, filename, file_data, mime_type)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(generation_id, filename, file_data, mime_type),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def _rows_to_history(conn: sqlite3.Connection, rows) -> list[dict]:
|
||||||
|
"""Convert raw generation_history rows (with optional share_token) to dicts."""
|
||||||
|
result: list[dict] = []
|
||||||
|
for row in rows:
|
||||||
|
d = dict(row)
|
||||||
|
if d["overrides"]:
|
||||||
|
try:
|
||||||
|
d["overrides"] = json.loads(d["overrides"])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
d["overrides"] = {}
|
||||||
|
else:
|
||||||
|
d["overrides"] = {}
|
||||||
|
|
||||||
|
files = conn.execute(
|
||||||
|
"SELECT filename FROM generation_files WHERE generation_id = ?",
|
||||||
|
(d["id"],),
|
||||||
|
).fetchall()
|
||||||
|
d["file_paths"] = [f["filename"] for f in files]
|
||||||
|
result.append(d)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_history(limit: int = 50) -> list[dict]:
|
||||||
|
"""Return recent generation rows (newest first) with a ``file_paths`` list."""
|
||||||
|
with _connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
|
||||||
|
s.share_token
|
||||||
|
FROM generation_history h
|
||||||
|
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label
|
||||||
|
ORDER BY h.id DESC LIMIT ?
|
||||||
|
""",
|
||||||
|
(limit,),
|
||||||
|
).fetchall()
|
||||||
|
return _rows_to_history(conn, rows)
|
||||||
|
|
||||||
|
|
||||||
|
def get_history_for_user(user_label: str, limit: int = 50) -> list[dict]:
|
||||||
|
"""Return recent generation rows for a specific user (newest first)."""
|
||||||
|
with _connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
|
||||||
|
s.share_token
|
||||||
|
FROM generation_history h
|
||||||
|
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ?
|
||||||
|
WHERE h.user_label = ?
|
||||||
|
ORDER BY h.id DESC LIMIT ?
|
||||||
|
""",
|
||||||
|
(user_label, user_label, limit),
|
||||||
|
).fetchall()
|
||||||
|
return _rows_to_history(conn, rows)
|
||||||
|
|
||||||
|
|
||||||
|
def get_generation(prompt_id: str) -> dict | None:
|
||||||
|
"""Return the generation_history row for *prompt_id*, or None."""
|
||||||
|
with _connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT id, prompt_id, user_label FROM generation_history WHERE prompt_id = ?",
|
||||||
|
(prompt_id,),
|
||||||
|
).fetchone()
|
||||||
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
|
||||||
|
def get_generation_full(prompt_id: str) -> dict | None:
|
||||||
|
"""Return overrides (parsed) + seed for *prompt_id*, or None if not found."""
|
||||||
|
with _connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT prompt_id, user_label, overrides, seed FROM generation_history WHERE prompt_id = ?",
|
||||||
|
(prompt_id,),
|
||||||
|
).fetchone()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
d = dict(row)
|
||||||
|
if d["overrides"]:
|
||||||
|
try:
|
||||||
|
d["overrides"] = json.loads(d["overrides"])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
d["overrides"] = {}
|
||||||
|
else:
|
||||||
|
d["overrides"] = {}
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def search_history_for_user(user_label: str, query: str, limit: int = 50) -> list[dict]:
|
||||||
|
"""Return history rows where the overrides JSON contains *query* (case-insensitive)."""
|
||||||
|
with _connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
|
||||||
|
s.share_token
|
||||||
|
FROM generation_history h
|
||||||
|
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = ?
|
||||||
|
WHERE h.user_label = ? AND LOWER(h.overrides) LIKE LOWER(?)
|
||||||
|
ORDER BY h.id DESC LIMIT ?
|
||||||
|
""",
|
||||||
|
(user_label, user_label, f"%{query}%", limit),
|
||||||
|
).fetchall()
|
||||||
|
return _rows_to_history(conn, rows)
|
||||||
|
|
||||||
|
|
||||||
|
def search_history(query: str, limit: int = 50) -> list[dict]:
|
||||||
|
"""Admin version: search all users' history for *query* in overrides JSON."""
|
||||||
|
with _connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT h.id, h.prompt_id, h.source, h.user_label, h.overrides, h.seed, h.created_at,
|
||||||
|
s.share_token
|
||||||
|
FROM generation_history h
|
||||||
|
LEFT JOIN generation_shares s ON h.prompt_id = s.prompt_id AND s.owner_label = h.user_label
|
||||||
|
WHERE LOWER(h.overrides) LIKE LOWER(?)
|
||||||
|
ORDER BY h.id DESC LIMIT ?
|
||||||
|
""",
|
||||||
|
(f"%{query}%", limit),
|
||||||
|
).fetchall()
|
||||||
|
return _rows_to_history(conn, rows)
|
||||||
|
|
||||||
|
|
||||||
|
def create_share(prompt_id: str, owner_label: str) -> str:
|
||||||
|
"""Create a share token for *prompt_id*. Idempotent — returns the same token if one exists."""
|
||||||
|
token = secrets.token_urlsafe(32)
|
||||||
|
created_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
with _connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR IGNORE INTO generation_shares (share_token, prompt_id, owner_label, created_at)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(token, prompt_id, owner_label, created_at),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT share_token FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
|
||||||
|
(prompt_id, owner_label),
|
||||||
|
).fetchone()
|
||||||
|
return row["share_token"]
|
||||||
|
|
||||||
|
|
||||||
|
def revoke_share(prompt_id: str, owner_label: str) -> bool:
|
||||||
|
"""Delete the share token for *prompt_id*. Returns True if a row was deleted."""
|
||||||
|
with _connect() as conn:
|
||||||
|
cur = conn.execute(
|
||||||
|
"DELETE FROM generation_shares WHERE prompt_id = ? AND owner_label = ?",
|
||||||
|
(prompt_id, owner_label),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return cur.rowcount > 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_share_by_token(token: str) -> dict | None:
|
||||||
|
"""Return generation info for a share token, or None if not found/revoked."""
|
||||||
|
with _connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT h.prompt_id, h.overrides, h.seed, h.created_at
|
||||||
|
FROM generation_shares s
|
||||||
|
JOIN generation_history h ON h.prompt_id = s.prompt_id
|
||||||
|
WHERE s.share_token = ?
|
||||||
|
""",
|
||||||
|
(token,),
|
||||||
|
).fetchone()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
d = dict(row)
|
||||||
|
if d["overrides"]:
|
||||||
|
try:
|
||||||
|
d["overrides"] = json.loads(d["overrides"])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
d["overrides"] = {}
|
||||||
|
else:
|
||||||
|
d["overrides"] = {}
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def get_files(prompt_id: str) -> list[dict]:
|
||||||
|
"""Return all output files for *prompt_id* as ``[{filename, data, mime_type}]``."""
|
||||||
|
with _connect() as conn:
|
||||||
|
gen_row = conn.execute(
|
||||||
|
"SELECT id FROM generation_history WHERE prompt_id = ?",
|
||||||
|
(prompt_id,),
|
||||||
|
).fetchone()
|
||||||
|
if not gen_row:
|
||||||
|
return []
|
||||||
|
|
||||||
|
files = conn.execute(
|
||||||
|
"SELECT filename, file_data, mime_type FROM generation_files WHERE generation_id = ?",
|
||||||
|
(gen_row["id"],),
|
||||||
|
).fetchall()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"filename": f["filename"],
|
||||||
|
"data": bytes(f["file_data"]),
|
||||||
|
"mime_type": f["mime_type"],
|
||||||
|
}
|
||||||
|
for f in files
|
||||||
|
]
|
||||||
84
image_utils.py
Normal file
84
image_utils.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""
|
||||||
|
image_utils.py
|
||||||
|
==============
|
||||||
|
|
||||||
|
Shared image processing utilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
try:
|
||||||
|
from PIL import Image as _PILImage
|
||||||
|
_HAS_PIL = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_PIL = False
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DISCORD_MAX_BYTES = 8 * 1024 * 1024 # 8 MiB Discord free-tier upload limit
|
||||||
|
|
||||||
|
|
||||||
|
def compress_to_discord_limit(data: bytes, filename: str) -> tuple[bytes, str]:
|
||||||
|
"""
|
||||||
|
Compress image bytes to fit within Discord's 8 MiB upload limit.
|
||||||
|
|
||||||
|
Tries quality reduction first, then progressive downsizing.
|
||||||
|
Converts to JPEG if the source format (e.g. PNG) cannot be quality-compressed.
|
||||||
|
Returns (data, filename) — filename may change if format conversion is needed.
|
||||||
|
No-op if data is already within the limit.
|
||||||
|
"""
|
||||||
|
if len(data) <= DISCORD_MAX_BYTES:
|
||||||
|
logger.info(f"Image size is less than {DISCORD_MAX_BYTES}")
|
||||||
|
return data, filename
|
||||||
|
|
||||||
|
if not _HAS_PIL:
|
||||||
|
logger.warning("Pillow not installed — cannot compress %s (%d bytes), uploading as-is", filename, len(data))
|
||||||
|
return data, filename
|
||||||
|
|
||||||
|
suffix = Path(filename).suffix.lower()
|
||||||
|
stem = Path(filename).stem
|
||||||
|
|
||||||
|
img = _PILImage.open(io.BytesIO(data))
|
||||||
|
|
||||||
|
# PNG/GIF/BMP don't support lossy quality — convert to JPEG
|
||||||
|
logger.info("Checking file extension for convert to jpeg")
|
||||||
|
if suffix in (".jpg", ".jpeg"):
|
||||||
|
fmt, out_name = "JPEG", filename
|
||||||
|
elif suffix == ".webp":
|
||||||
|
fmt, out_name = "WEBP", filename
|
||||||
|
else:
|
||||||
|
fmt, out_name = "JPEG", stem + ".jpg"
|
||||||
|
if img.mode in ("RGBA", "P", "LA"):
|
||||||
|
img = img.convert("RGB")
|
||||||
|
logger.info("File extension checked")
|
||||||
|
|
||||||
|
# Round 1: quality reduction only
|
||||||
|
logger.info("# Round 1: quality reduction only")
|
||||||
|
for quality in (85, 70, 55, 40, 25, 15):
|
||||||
|
logger.info(f"# Round 1: Trying quality: {quality}")
|
||||||
|
buf = io.BytesIO()
|
||||||
|
img.save(buf, format=fmt, quality=quality)
|
||||||
|
if buf.tell() <= DISCORD_MAX_BYTES:
|
||||||
|
logger.info("Compressed %s at quality=%d: %d → %d bytes", filename, quality, len(data), buf.tell())
|
||||||
|
return buf.getvalue(), out_name
|
||||||
|
|
||||||
|
# Round 2: resize + low quality
|
||||||
|
logger.info("# Round 2: resize + low quality")
|
||||||
|
for scale in (0.75, 0.5, 0.35, 0.25):
|
||||||
|
w, h = img.size
|
||||||
|
resized = img.resize((int(w * scale), int(h * scale)), _PILImage.LANCZOS)
|
||||||
|
logger.info(f"# Round 2: Trying to resize: {resized.size}")
|
||||||
|
buf = io.BytesIO()
|
||||||
|
resized.save(buf, format=fmt, quality=15)
|
||||||
|
if buf.tell() <= DISCORD_MAX_BYTES:
|
||||||
|
logger.info("Compressed %s at scale=%.2f: %d → %d bytes", filename, scale, len(data), buf.tell())
|
||||||
|
return buf.getvalue(), out_name
|
||||||
|
|
||||||
|
logger.warning("Could not compress %s under %d bytes — uploading best effort", filename, DISCORD_MAX_BYTES)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
img.save(buf, format=fmt, quality=10)
|
||||||
|
return buf.getvalue(), out_name
|
||||||
232
input_image_db.py
Normal file
232
input_image_db.py
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
"""
|
||||||
|
input_image_db.py
|
||||||
|
=================
|
||||||
|
|
||||||
|
SQLite helpers for tracking Discord-channel-backed input images.
|
||||||
|
The database stores one row per attachment, so a single message with
|
||||||
|
multiple images produces multiple rows. The stable lookup key is the
|
||||||
|
auto-increment `id`; `(original_message_id, filename)` is a unique
|
||||||
|
composite constraint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
DB_PATH = Path(__file__).parent / "input_images.db"
|
||||||
|
|
||||||
|
_SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS input_images (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
original_message_id INTEGER NOT NULL,
|
||||||
|
bot_reply_id INTEGER NOT NULL,
|
||||||
|
channel_id INTEGER NOT NULL,
|
||||||
|
filename TEXT NOT NULL,
|
||||||
|
is_active INTEGER NOT NULL DEFAULT 0,
|
||||||
|
image_data BLOB,
|
||||||
|
active_slot_key TEXT DEFAULT NULL,
|
||||||
|
UNIQUE(original_message_id, filename)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Live migrations applied on every startup — safe if column already exists
|
||||||
|
_MIGRATIONS = [
|
||||||
|
"ALTER TABLE input_images ADD COLUMN image_data BLOB",
|
||||||
|
"ALTER TABLE input_images ADD COLUMN active_slot_key TEXT DEFAULT NULL",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Columns returned by get_image / get_all_images (excludes the potentially large BLOB)
|
||||||
|
_SAFE_COLS = "id, original_message_id, bot_reply_id, channel_id, filename, is_active, active_slot_key"
|
||||||
|
|
||||||
|
|
||||||
|
def _connect() -> sqlite3.Connection:
|
||||||
|
conn = sqlite3.connect(str(DB_PATH))
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
def init_db() -> None:
|
||||||
|
"""Create the input_images table if it does not exist, and run column migrations."""
|
||||||
|
with _connect() as conn:
|
||||||
|
conn.execute(_SCHEMA)
|
||||||
|
for stmt in _MIGRATIONS:
|
||||||
|
try:
|
||||||
|
conn.execute(stmt)
|
||||||
|
except sqlite3.OperationalError:
|
||||||
|
pass # column already exists
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_image(
|
||||||
|
original_message_id: int,
|
||||||
|
bot_reply_id: int,
|
||||||
|
channel_id: int,
|
||||||
|
filename: str,
|
||||||
|
image_data: bytes | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Insert a new image record or update an existing one.
|
||||||
|
|
||||||
|
Returns the stable row ``id`` (used as the persistent view key).
|
||||||
|
When *image_data* is provided it is stored as a BLOB; on UPDATE it is
|
||||||
|
only overwritten when not None.
|
||||||
|
"""
|
||||||
|
with _connect() as conn:
|
||||||
|
existing = conn.execute(
|
||||||
|
"SELECT id FROM input_images WHERE original_message_id = ? AND filename = ?",
|
||||||
|
(original_message_id, filename),
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
if image_data is not None:
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE input_images SET bot_reply_id = ?, channel_id = ?, image_data = ? WHERE id = ?",
|
||||||
|
(bot_reply_id, channel_id, image_data, existing["id"]),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE input_images SET bot_reply_id = ?, channel_id = ? WHERE id = ?",
|
||||||
|
(bot_reply_id, channel_id, existing["id"]),
|
||||||
|
)
|
||||||
|
row_id = existing["id"]
|
||||||
|
else:
|
||||||
|
cur = conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO input_images (original_message_id, bot_reply_id, channel_id, filename, is_active, image_data)
|
||||||
|
VALUES (?, ?, ?, ?, 0, ?)
|
||||||
|
""",
|
||||||
|
(original_message_id, bot_reply_id, channel_id, filename, image_data),
|
||||||
|
)
|
||||||
|
row_id = cur.lastrowid
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
return row_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_data(row_id: int) -> bytes | None:
|
||||||
|
"""Return the raw image bytes for a row, or None if the row is missing or has no data."""
|
||||||
|
with _connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT image_data FROM input_images WHERE id = ?",
|
||||||
|
(row_id,),
|
||||||
|
).fetchone()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return row["image_data"]
|
||||||
|
|
||||||
|
|
||||||
|
def activate_image_for_slot(row_id: int, slot_key: str, comfy_input_path: str) -> str:
|
||||||
|
"""
|
||||||
|
Write the stored image bytes to ``{comfy_input_path}/ttb_{slot_key}{ext}``
|
||||||
|
and record the slot assignment in the DB.
|
||||||
|
|
||||||
|
Returns the basename of the written file (e.g. ``ttb_input_image.jpg``).
|
||||||
|
Raises ``ValueError`` if the row has no image_data (user must re-upload).
|
||||||
|
"""
|
||||||
|
data = get_image_data(row_id)
|
||||||
|
if data is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"No image data stored for row {row_id}. Re-upload the image to backfill."
|
||||||
|
)
|
||||||
|
|
||||||
|
row = get_image(row_id)
|
||||||
|
if row is None:
|
||||||
|
raise ValueError(f"No DB record for row id {row_id}")
|
||||||
|
|
||||||
|
ext = Path(row["filename"]).suffix # e.g. ".jpg"
|
||||||
|
dest_name = f"ttb_{slot_key}{ext}"
|
||||||
|
input_path = Path(comfy_input_path)
|
||||||
|
input_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Remove any existing file for this slot (may have a different extension)
|
||||||
|
for old in input_path.glob(f"ttb_{slot_key}.*"):
|
||||||
|
try:
|
||||||
|
old.unlink()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
(input_path / dest_name).write_bytes(data)
|
||||||
|
|
||||||
|
# Update DB: clear slot from previous holder, then assign to this row
|
||||||
|
with _connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE input_images SET active_slot_key = NULL WHERE active_slot_key = ?",
|
||||||
|
(slot_key,),
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE input_images SET active_slot_key = ? WHERE id = ?",
|
||||||
|
(slot_key, row_id),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
return dest_name
|
||||||
|
|
||||||
|
|
||||||
|
def deactivate_image_slot(slot_key: str, comfy_input_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Remove the ``ttb_{slot_key}.*`` file from the ComfyUI input folder and
|
||||||
|
clear the matching DB column. Safe no-op if nothing is active for that slot.
|
||||||
|
"""
|
||||||
|
input_path = Path(comfy_input_path)
|
||||||
|
for old in input_path.glob(f"ttb_{slot_key}.*"):
|
||||||
|
try:
|
||||||
|
old.unlink()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
with _connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE input_images SET active_slot_key = NULL WHERE active_slot_key = ?",
|
||||||
|
(slot_key,),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def set_active(row_id: int) -> None:
|
||||||
|
"""Mark one image as active and clear the active flag on all others."""
|
||||||
|
with _connect() as conn:
|
||||||
|
conn.execute("UPDATE input_images SET is_active = 0")
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE input_images SET is_active = 1 WHERE id = ?",
|
||||||
|
(row_id,),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def get_image(row_id: int) -> dict | None:
|
||||||
|
"""Return a single image row by its auto-increment id (excluding image_data), or None."""
|
||||||
|
with _connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
f"SELECT {_SAFE_COLS} FROM input_images WHERE id = ?",
|
||||||
|
(row_id,),
|
||||||
|
).fetchone()
|
||||||
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_images() -> list[dict]:
|
||||||
|
"""Return all image rows as a list of dicts (excluding image_data)."""
|
||||||
|
with _connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
f"SELECT {_SAFE_COLS} FROM input_images"
|
||||||
|
).fetchall()
|
||||||
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def delete_image(row_id: int, comfy_input_path: str | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Remove an image record from the database.
|
||||||
|
|
||||||
|
If the image is currently active for a slot and *comfy_input_path* is
|
||||||
|
provided, the corresponding ``ttb_{slot_key}.*`` file is also deleted.
|
||||||
|
"""
|
||||||
|
row = get_image(row_id)
|
||||||
|
if row and row.get("active_slot_key") and comfy_input_path:
|
||||||
|
deactivate_image_slot(row["active_slot_key"], comfy_input_path)
|
||||||
|
|
||||||
|
with _connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"DELETE FROM input_images WHERE id = ?",
|
||||||
|
(row_id,),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
286
media_uploader.py
Normal file
286
media_uploader.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
"""
|
||||||
|
media_uploader.py
|
||||||
|
=================
|
||||||
|
|
||||||
|
Auto-uploads generated media files to the external storage server.
|
||||||
|
|
||||||
|
On success the local file is deleted. Any files that fail to upload
|
||||||
|
are left in place and will be retried automatically on the next call
|
||||||
|
(i.e. the next time a generation completes).
|
||||||
|
|
||||||
|
If no credentials are configured (MEDIA_UPLOAD_USER / MEDIA_UPLOAD_PASS
|
||||||
|
env vars not set), flush_pending() is a no-op and files are left for the
|
||||||
|
manual ttr!collect-videos command.
|
||||||
|
|
||||||
|
Upload behaviour:
|
||||||
|
- Files are categorised into image / gif / video / audio folders.
|
||||||
|
- A ``folder`` form field is sent with each upload so the server can
|
||||||
|
route the file into the correct subdirectory.
|
||||||
|
- The current datetime is appended to each filename before uploading
|
||||||
|
(e.g. ``output_20260225_143022.png``); the local filename is unchanged.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
from media_uploader import flush_pending, get_stats
|
||||||
|
await flush_pending(Path(config.comfy_output_path),
|
||||||
|
config.media_upload_user,
|
||||||
|
config.media_upload_pass)
|
||||||
|
stats = get_stats()
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
import ssl
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Constants
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
UPLOAD_URL = "https://mediaup.revoluxiant.ddns.net/upload"
|
||||||
|
|
||||||
|
# Media categories and their recognised extensions.
|
||||||
|
CATEGORY_EXTENSIONS: dict[str, frozenset[str]] = {
|
||||||
|
"image": frozenset({
|
||||||
|
".png", ".jpg", ".jpeg", ".webp", ".bmp",
|
||||||
|
".tiff", ".tif", ".avif", ".heic", ".heif", ".svg", ".ico",
|
||||||
|
}),
|
||||||
|
"gif": frozenset({
|
||||||
|
".gif",
|
||||||
|
}),
|
||||||
|
"video": frozenset({
|
||||||
|
".mp4", ".webm", ".avi", ".mkv", ".mov",
|
||||||
|
".flv", ".ts", ".m2ts", ".m4v", ".wmv",
|
||||||
|
}),
|
||||||
|
"audio": frozenset({
|
||||||
|
".mp3", ".wav", ".ogg", ".flac", ".aac",
|
||||||
|
".m4a", ".opus", ".wma", ".aiff", ".aif",
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Flat set of all recognised extensions (used for directory scanning).
|
||||||
|
MEDIA_EXTENSIONS: frozenset[str] = frozenset().union(*CATEGORY_EXTENSIONS.values())
|
||||||
|
|
||||||
|
# Shared SSL context — server uses a self-signed cert.
|
||||||
|
_ssl_ctx = ssl.create_default_context()
|
||||||
|
_ssl_ctx.check_hostname = False
|
||||||
|
_ssl_ctx.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
# Prevents concurrent flush runs from uploading the same file twice.
|
||||||
|
_flush_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Stats
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UploadStats:
|
||||||
|
"""Cumulative upload counters for the current bot session."""
|
||||||
|
total_attempted: int = 0
|
||||||
|
total_ok: int = 0
|
||||||
|
last_attempted: int = 0
|
||||||
|
last_ok: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_fail(self) -> int:
|
||||||
|
return self.total_attempted - self.total_ok
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_fail(self) -> int:
|
||||||
|
return self.last_attempted - self.last_ok
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fail_rate_pct(self) -> float:
|
||||||
|
if self.total_attempted == 0:
|
||||||
|
return 0.0
|
||||||
|
return (self.total_fail / self.total_attempted) * 100.0
|
||||||
|
|
||||||
|
|
||||||
|
_stats = UploadStats()
|
||||||
|
|
||||||
|
|
||||||
|
def get_stats() -> UploadStats:
|
||||||
|
"""Return the module-level upload stats (live reference)."""
|
||||||
|
return _stats
|
||||||
|
|
||||||
|
|
||||||
|
def is_running() -> bool:
|
||||||
|
"""Return True if a flush is currently in progress."""
|
||||||
|
return _flush_lock.locked()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _content_type(filepath: Path) -> str:
|
||||||
|
mime, _ = mimetypes.guess_type(filepath.name)
|
||||||
|
return mime or "application/octet-stream"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_category(suffix: str) -> str:
|
||||||
|
"""Return the upload folder category for a file extension."""
|
||||||
|
s = suffix.lower()
|
||||||
|
for category, extensions in CATEGORY_EXTENSIONS.items():
|
||||||
|
if s in extensions:
|
||||||
|
logger.info(f"[_get_category] File category: {category}")
|
||||||
|
return category
|
||||||
|
|
||||||
|
logger.info(f"[_get_category] File category: other")
|
||||||
|
return "other"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_upload_name(filepath: Path) -> str:
|
||||||
|
"""Return a filename with the current datetime appended before the extension.
|
||||||
|
|
||||||
|
Example: ``ComfyUI_00042.png`` → ``20260225_143022_ComfyUI_00042.png``
|
||||||
|
"""
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
return f"{timestamp}_{filepath.stem}{filepath.suffix}"
|
||||||
|
|
||||||
|
def _build_url(url, category: str) -> str:
|
||||||
|
"""Build url to upload to each specific folder"""
|
||||||
|
url += "/comfyui"
|
||||||
|
if category == "image":
|
||||||
|
url += "/image"
|
||||||
|
elif category == "video":
|
||||||
|
url += "/video"
|
||||||
|
elif category == "gif":
|
||||||
|
url += "/gif"
|
||||||
|
elif category == "audio":
|
||||||
|
url += "/audio"
|
||||||
|
else:
|
||||||
|
url
|
||||||
|
return url
|
||||||
|
|
||||||
|
async def _upload_one(session: aiohttp.ClientSession, filepath: Path) -> bool:
|
||||||
|
"""Upload a single file. Returns True on HTTP 2xx."""
|
||||||
|
try:
|
||||||
|
file_bytes = filepath.read_bytes()
|
||||||
|
except OSError:
|
||||||
|
logger.warning("Cannot read file for upload: %s", filepath)
|
||||||
|
return False
|
||||||
|
|
||||||
|
category = _get_category(filepath.suffix)
|
||||||
|
upload_name = _build_upload_name(filepath)
|
||||||
|
|
||||||
|
form = aiohttp.FormData()
|
||||||
|
form.add_field(
|
||||||
|
"file",
|
||||||
|
file_bytes,
|
||||||
|
filename=upload_name,
|
||||||
|
content_type=_content_type(filepath),
|
||||||
|
)
|
||||||
|
|
||||||
|
url = _build_url(UPLOAD_URL, category)
|
||||||
|
logger.info(f"Uploading file to url: {url}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.post(
|
||||||
|
url,
|
||||||
|
data=form,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=120),
|
||||||
|
) as resp:
|
||||||
|
if resp.status // 100 == 2:
|
||||||
|
return True
|
||||||
|
body = await resp.text()
|
||||||
|
logger.warning(
|
||||||
|
"Upload rejected %s: HTTP %s — %s",
|
||||||
|
upload_name,
|
||||||
|
resp.status,
|
||||||
|
body[:200],
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Upload error for %s", upload_name, exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def flush_pending(
|
||||||
|
output_path: Path,
|
||||||
|
user: Optional[str],
|
||||||
|
password: Optional[str],
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Scan *output_path* for media files, upload each to the storage server,
|
||||||
|
and delete the local file on success.
|
||||||
|
|
||||||
|
If *user* or *password* is falsy (not configured), returns 0 immediately
|
||||||
|
and leaves all files in place for the manual ttr!collect-videos command.
|
||||||
|
|
||||||
|
Files that fail to upload are left in place and will be retried on the
|
||||||
|
next call. If a previous flush is still running the new call returns 0
|
||||||
|
immediately to avoid double-uploading.
|
||||||
|
|
||||||
|
Returns the number of files successfully uploaded and deleted.
|
||||||
|
"""
|
||||||
|
if not user or not password:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if _flush_lock.locked():
|
||||||
|
logger.debug("flush_pending already in progress, skipping")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async with _flush_lock:
|
||||||
|
try:
|
||||||
|
entries = [
|
||||||
|
e for e in output_path.iterdir()
|
||||||
|
if e.is_file() and e.suffix.lower() in MEDIA_EXTENSIONS
|
||||||
|
]
|
||||||
|
except OSError:
|
||||||
|
logger.warning("Cannot scan output directory: %s", output_path)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if not entries:
|
||||||
|
_stats.last_attempted = 0
|
||||||
|
_stats.last_ok = 0
|
||||||
|
return 0
|
||||||
|
|
||||||
|
logger.info("Auto-uploading %d pending media file(s)…", len(entries))
|
||||||
|
|
||||||
|
auth = aiohttp.BasicAuth(user, password)
|
||||||
|
connector = aiohttp.TCPConnector(ssl=_ssl_ctx)
|
||||||
|
uploaded = 0
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(connector=connector, auth=auth) as session:
|
||||||
|
for filepath in entries:
|
||||||
|
if await _upload_one(session, filepath):
|
||||||
|
try:
|
||||||
|
filepath.unlink()
|
||||||
|
logger.info("Uploaded and deleted: %s", filepath.name)
|
||||||
|
uploaded += 1
|
||||||
|
except OSError:
|
||||||
|
logger.warning(
|
||||||
|
"Uploaded but could not delete local file: %s", filepath
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update cumulative stats
|
||||||
|
_stats.last_attempted = len(entries)
|
||||||
|
_stats.last_ok = uploaded
|
||||||
|
_stats.total_attempted += len(entries)
|
||||||
|
_stats.total_ok += uploaded
|
||||||
|
|
||||||
|
failed = len(entries) - uploaded
|
||||||
|
if failed:
|
||||||
|
logger.warning(
|
||||||
|
"Auto-upload: %d ok, %d failed — will retry next generation.",
|
||||||
|
uploaded,
|
||||||
|
failed,
|
||||||
|
)
|
||||||
|
elif uploaded:
|
||||||
|
logger.info("Auto-upload complete: %d file(s) uploaded and deleted.", uploaded)
|
||||||
|
|
||||||
|
return uploaded
|
||||||
206
preset_manager.py
Normal file
206
preset_manager.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
"""
|
||||||
|
preset_manager.py
|
||||||
|
=================
|
||||||
|
|
||||||
|
Preset management for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
A preset is a named snapshot of the current workflow template and runtime
|
||||||
|
state (prompt, negative_prompt, input_image, seed). Presets are stored as
|
||||||
|
individual JSON files inside a ``presets/`` directory so they can be
|
||||||
|
inspected, backed up, or shared manually.
|
||||||
|
|
||||||
|
Usage via bot commands:
|
||||||
|
ttr!preset save <name> — capture current workflow + state
|
||||||
|
ttr!preset load <name> — restore workflow + state from snapshot
|
||||||
|
ttr!preset list — list all saved presets
|
||||||
|
ttr!preset delete <name> — permanently remove a preset
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Only allow alphanumeric characters, hyphens, and underscores in preset names.
|
||||||
|
_SAFE_NAME_RE = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
|
||||||
|
|
||||||
|
|
||||||
|
class PresetManager:
|
||||||
|
"""
|
||||||
|
Manages named workflow presets on disk.
|
||||||
|
|
||||||
|
Each preset is stored as ``<presets_dir>/<name>.json`` and contains:
|
||||||
|
- ``name``: the preset name
|
||||||
|
- ``workflow``: the full ComfyUI workflow template dict (may be null)
|
||||||
|
- ``state``: the runtime state changes (prompt, negative_prompt,
|
||||||
|
input_image, seed)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
presets_dir : str
|
||||||
|
Directory where preset files are stored. Created automatically if
|
||||||
|
it does not exist.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, presets_dir: str = "presets") -> None:
|
||||||
|
self.presets_dir = Path(presets_dir)
|
||||||
|
self.presets_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_valid_name(name: str) -> bool:
|
||||||
|
"""Return True if the name only contains safe characters."""
|
||||||
|
return bool(_SAFE_NAME_RE.match(name))
|
||||||
|
|
||||||
|
def _path(self, name: str) -> Path:
|
||||||
|
"""Return the file path for a preset by name (no validation)."""
|
||||||
|
return self.presets_dir / f"{name}.json"
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
workflow_template: Optional[dict[str, Any]],
|
||||||
|
state: dict[str, Any],
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Save a preset to disk.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The preset name (alphanumeric, hyphens, underscores only).
|
||||||
|
workflow_template : Optional[dict]
|
||||||
|
The current workflow template, or None if none is loaded.
|
||||||
|
state : dict
|
||||||
|
The current runtime state from WorkflowStateManager.get_changes().
|
||||||
|
owner : Optional[str]
|
||||||
|
The user label of the preset creator. Stored for access control.
|
||||||
|
description : Optional[str]
|
||||||
|
A human-readable description of what this preset does.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the name contains invalid characters.
|
||||||
|
OSError
|
||||||
|
If the file cannot be written.
|
||||||
|
"""
|
||||||
|
if not self.is_valid_name(name):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid preset name '{name}'. "
|
||||||
|
"Use only letters, digits, hyphens, and underscores (max 64 chars)."
|
||||||
|
)
|
||||||
|
data: dict[str, Any] = {"name": name, "workflow": workflow_template, "state": state}
|
||||||
|
if owner is not None:
|
||||||
|
data["owner"] = owner
|
||||||
|
if description is not None:
|
||||||
|
data["description"] = description
|
||||||
|
path = self._path(name)
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(data, f, indent=2)
|
||||||
|
logger.info("Saved preset '%s' to %s", name, path)
|
||||||
|
|
||||||
|
def load(self, name: str) -> Optional[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Load a preset from disk.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The preset name.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Optional[dict]
|
||||||
|
The preset data dict, or None if the preset does not exist.
|
||||||
|
"""
|
||||||
|
if not self.is_valid_name(name):
|
||||||
|
return None
|
||||||
|
path = self._path(name)
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to load preset '%s': %s", name, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def delete(self, name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a preset file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The preset name.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the preset existed and was deleted, False otherwise.
|
||||||
|
"""
|
||||||
|
if not self.is_valid_name(name):
|
||||||
|
return False
|
||||||
|
path = self._path(name)
|
||||||
|
if path.exists():
|
||||||
|
path.unlink()
|
||||||
|
logger.info("Deleted preset '%s'", name)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_presets(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
List all saved preset names, sorted alphabetically.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[str]
|
||||||
|
Sorted list of preset names (without the .json extension).
|
||||||
|
"""
|
||||||
|
return sorted(p.stem for p in self.presets_dir.glob("*.json"))
|
||||||
|
|
||||||
|
def list_preset_details(self) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
List all presets with their metadata, sorted alphabetically by name.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[dict]
|
||||||
|
Each entry has ``"name"``, ``"owner"`` (may be None), and
|
||||||
|
``"description"`` (may be None).
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
for p in sorted(self.presets_dir.glob("*.json"), key=lambda x: x.stem):
|
||||||
|
owner = None
|
||||||
|
description = None
|
||||||
|
try:
|
||||||
|
with open(p, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
owner = data.get("owner")
|
||||||
|
description = data.get("description")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
result.append({"name": p.stem, "owner": owner, "description": description})
|
||||||
|
return result
|
||||||
|
|
||||||
|
def exists(self, name: str) -> bool:
|
||||||
|
"""Return True if a preset with this name exists on disk."""
|
||||||
|
if not self.is_valid_name(name):
|
||||||
|
return False
|
||||||
|
return self._path(name).exists()
|
||||||
521
status_monitor.py
Normal file
521
status_monitor.py
Normal file
@@ -0,0 +1,521 @@
|
|||||||
|
"""
|
||||||
|
status_monitor.py
|
||||||
|
=================
|
||||||
|
|
||||||
|
Live status dashboard for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
Edits a single pinned message in a designated log channel every
|
||||||
|
``update_interval`` seconds. Changed values are highlighted with
|
||||||
|
bold text and directional arrows/emoji so differences are immediately
|
||||||
|
obvious.
|
||||||
|
|
||||||
|
Change-highlighting rules
|
||||||
|
-------------------------
|
||||||
|
- Unchanged good state → 🟢 value
|
||||||
|
- Unchanged bad state → 🔴 value
|
||||||
|
- Changed → bad → ⚠️ **value**
|
||||||
|
- Changed → good → ✅ **value**
|
||||||
|
- Changed (neutral) → **value**
|
||||||
|
- Queue size changed → **N** ▲ or **N** ▼
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import discord
|
||||||
|
|
||||||
|
from commands.server import get_service_state, STATUS_EMOJI
|
||||||
|
from media_uploader import get_stats as get_upload_stats, is_running as upload_is_running, MEDIA_EXTENSIONS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module-level helpers (no discord.py dependency)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _format_uptime(start: datetime) -> str:
|
||||||
|
"""Return a human-readable uptime string from a UTC start time."""
|
||||||
|
delta = datetime.now(timezone.utc) - start
|
||||||
|
total = int(delta.total_seconds())
|
||||||
|
h, rem = divmod(total, 3600)
|
||||||
|
m, s = divmod(rem, 60)
|
||||||
|
if h:
|
||||||
|
return f"{h}h {m}m {s}s"
|
||||||
|
if m:
|
||||||
|
return f"{m}m {s}s"
|
||||||
|
return f"{s}s"
|
||||||
|
|
||||||
|
|
||||||
|
def _elapsed(start: datetime) -> str:
|
||||||
|
"""Return elapsed time string since *start* (UTC)."""
|
||||||
|
delta = datetime.now(timezone.utc) - start
|
||||||
|
total = int(delta.total_seconds())
|
||||||
|
h, rem = divmod(total, 3600)
|
||||||
|
m, s = divmod(rem, 60)
|
||||||
|
if h:
|
||||||
|
return f"{h}:{m:02d}:{s:02d}"
|
||||||
|
return f"{m}:{s:02d}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# StatusMonitor
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class StatusMonitor:
|
||||||
|
"""
|
||||||
|
Periodically edits a single Discord message with live bot status.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
bot :
|
||||||
|
The discord.ext.commands.Bot instance.
|
||||||
|
channel_id : int
|
||||||
|
ID of the Discord channel used for the dashboard message.
|
||||||
|
update_interval : float
|
||||||
|
Seconds between updates (default 5).
|
||||||
|
"""
|
||||||
|
|
||||||
|
HEADER_MARKER = "📊"
|
||||||
|
|
||||||
|
def __init__(self, bot, channel_id: int, update_interval: float = 10.0) -> None:
|
||||||
|
self._bot = bot
|
||||||
|
self._channel_id = channel_id
|
||||||
|
self._interval = update_interval
|
||||||
|
self._prev: dict[str, str] = {}
|
||||||
|
self._message: Optional[discord.Message] = None
|
||||||
|
self._task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public lifecycle
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Start the update loop (idempotent)."""
|
||||||
|
if self._task is None or self._task.done():
|
||||||
|
self._task = asyncio.create_task(self._update_loop())
|
||||||
|
logger.info("StatusMonitor started for channel %s", self._channel_id)
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Cancel the update loop, then send shutdown notice (idempotent)."""
|
||||||
|
if self._task and not self._task.done():
|
||||||
|
self._task.cancel()
|
||||||
|
try:
|
||||||
|
await self._task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
await self._send_shutdown_message()
|
||||||
|
logger.info("StatusMonitor stopped")
|
||||||
|
|
||||||
|
async def _send_shutdown_message(self) -> None:
|
||||||
|
"""Immediately edit the dashboard message to show bot offline status.
|
||||||
|
|
||||||
|
Uses a fresh aiohttp session with the bot token directly, because
|
||||||
|
discord.py closes its own HTTP session before our finally block runs
|
||||||
|
on Ctrl-C / task cancellation, making Message.edit() silently fail.
|
||||||
|
"""
|
||||||
|
if self._message is None:
|
||||||
|
# No cached message; can't create one safely during shutdown.
|
||||||
|
return
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
vn_time = now + timedelta(hours=7)
|
||||||
|
utc_str = now.strftime("%H:%M:%S UTC")
|
||||||
|
vn_str = vn_time.strftime("%H:%M:%S GMT+7")
|
||||||
|
text = (
|
||||||
|
f"{self.HEADER_MARKER} 🔴 **Bot Status Dashboard** — OFFLINE\n"
|
||||||
|
f"-# Shut down at: {utc_str} ({vn_str})\n"
|
||||||
|
"\n"
|
||||||
|
"Bot process has stopped."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
token = self._bot.http.token
|
||||||
|
url = (
|
||||||
|
f"https://discord.com/api/v10/channels/"
|
||||||
|
f"{self._channel_id}/messages/{self._message.id}"
|
||||||
|
)
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bot {token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.patch(url, json={"content": text}, headers=headers) as resp:
|
||||||
|
if resp.status not in (200, 204):
|
||||||
|
logger.warning(
|
||||||
|
"StatusMonitor: shutdown edit returned HTTP %s", resp.status
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("StatusMonitor: could not send shutdown message: %s", exc)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _get_or_create_message(self) -> Optional[discord.Message]:
|
||||||
|
"""
|
||||||
|
Return the existing dashboard message or create a new one.
|
||||||
|
|
||||||
|
Searches the last 20 messages in the channel for an existing
|
||||||
|
dashboard (posted by this bot and containing the header marker).
|
||||||
|
"""
|
||||||
|
channel = self._bot.get_channel(self._channel_id)
|
||||||
|
if channel is None:
|
||||||
|
try:
|
||||||
|
channel = await self._bot.fetch_channel(self._channel_id)
|
||||||
|
except discord.NotFound:
|
||||||
|
logger.error("StatusMonitor: channel %s not found", self._channel_id)
|
||||||
|
return None
|
||||||
|
except discord.Forbidden:
|
||||||
|
logger.error("StatusMonitor: no access to channel %s", self._channel_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Try to find an existing dashboard message
|
||||||
|
try:
|
||||||
|
async for msg in channel.history(limit=20):
|
||||||
|
if msg.author == self._bot.user and self.HEADER_MARKER in msg.content:
|
||||||
|
return msg
|
||||||
|
except discord.HTTPException as exc:
|
||||||
|
logger.warning("StatusMonitor: history fetch failed: %s", exc)
|
||||||
|
|
||||||
|
# None found — create a fresh one
|
||||||
|
try:
|
||||||
|
msg = await channel.send(f"{self.HEADER_MARKER} **Bot Status Dashboard**\n-# Initializing…")
|
||||||
|
return msg
|
||||||
|
except discord.HTTPException as exc:
|
||||||
|
logger.error("StatusMonitor: could not create dashboard message: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _collect_sync(self) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Read bot/workflow state synchronously (no async calls).
|
||||||
|
|
||||||
|
Returns a flat dict of string key → string value snapshots.
|
||||||
|
ComfyUI queue stats are filled in asynchronously in _update_loop.
|
||||||
|
"""
|
||||||
|
snap: dict[str, str] = {}
|
||||||
|
bot = self._bot
|
||||||
|
|
||||||
|
# --- Bot section ---
|
||||||
|
lat = bot.latency
|
||||||
|
latency_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0
|
||||||
|
snap["latency"] = f"{latency_ms} ms"
|
||||||
|
|
||||||
|
if hasattr(bot, "start_time"):
|
||||||
|
snap["uptime"] = _format_uptime(bot.start_time)
|
||||||
|
else:
|
||||||
|
snap["uptime"] = "unknown"
|
||||||
|
|
||||||
|
# --- ComfyUI queue section (filled async) ---
|
||||||
|
snap["comfy_pending"] = self._prev.get("comfy_pending", "?")
|
||||||
|
snap["comfy_running"] = self._prev.get("comfy_running", "?")
|
||||||
|
|
||||||
|
# --- ComfyUI section ---
|
||||||
|
comfy = getattr(bot, "comfy", None)
|
||||||
|
if comfy is not None:
|
||||||
|
snap["comfy_server"] = getattr(comfy, "server_address", "unknown")
|
||||||
|
wm = getattr(comfy, "workflow_manager", None)
|
||||||
|
workflow_loaded = wm is not None and wm.get_workflow_template() is not None
|
||||||
|
snap["workflow"] = "loaded" if workflow_loaded else "none"
|
||||||
|
|
||||||
|
sm = getattr(comfy, "state_manager", None)
|
||||||
|
if sm is not None:
|
||||||
|
changes = sm.get_changes()
|
||||||
|
p = changes.get("prompt") or ""
|
||||||
|
snap["prompt"] = (p[:50] + "…" if len(p) > 50 else p) if p else "—"
|
||||||
|
n = changes.get("negative_prompt") or ""
|
||||||
|
snap["neg_prompt"] = (n[:50] + "…" if len(n) > 50 else n) if n else "—"
|
||||||
|
img = changes.get("input_image")
|
||||||
|
snap["input_image"] = Path(img).name if img else "—"
|
||||||
|
seed_pin = changes.get("seed")
|
||||||
|
snap["pinned_seed"] = str(seed_pin) if seed_pin is not None else "random"
|
||||||
|
else:
|
||||||
|
snap["prompt"] = "—"
|
||||||
|
snap["neg_prompt"] = "—"
|
||||||
|
snap["input_image"] = "—"
|
||||||
|
snap["pinned_seed"] = "—"
|
||||||
|
|
||||||
|
last_seed = getattr(comfy, "last_seed", None)
|
||||||
|
snap["last_seed"] = str(last_seed) if last_seed is not None else "—"
|
||||||
|
snap["total_gen"] = str(getattr(comfy, "total_generated", 0))
|
||||||
|
else:
|
||||||
|
snap["comfy_server"] = "not configured"
|
||||||
|
snap["workflow"] = "—"
|
||||||
|
snap["prompt"] = "—"
|
||||||
|
snap["neg_prompt"] = "—"
|
||||||
|
snap["input_image"] = "—"
|
||||||
|
snap["pinned_seed"] = "—"
|
||||||
|
snap["last_seed"] = "—"
|
||||||
|
snap["total_gen"] = "0"
|
||||||
|
|
||||||
|
# comfy_status and service_state are filled in asynchronously
|
||||||
|
snap["comfy_status"] = self._prev.get("comfy_status", "unknown")
|
||||||
|
snap["service_state"] = self._prev.get("service_state", "unknown")
|
||||||
|
|
||||||
|
# --- Auto-upload section ---
|
||||||
|
config = getattr(bot, "config", None)
|
||||||
|
upload_user = getattr(config, "media_upload_user", None)
|
||||||
|
upload_configured = bool(upload_user)
|
||||||
|
snap["upload_configured"] = "enabled" if upload_configured else "disabled"
|
||||||
|
|
||||||
|
if upload_configured:
|
||||||
|
snap["upload_state"] = "uploading" if upload_is_running() else "idle"
|
||||||
|
|
||||||
|
# Pending: count media files sitting in the output directory
|
||||||
|
output_path_str = getattr(config, "comfy_output_path", None)
|
||||||
|
if output_path_str:
|
||||||
|
try:
|
||||||
|
pending = sum(
|
||||||
|
1 for e in Path(output_path_str).iterdir()
|
||||||
|
if e.is_file() and e.suffix.lower() in MEDIA_EXTENSIONS
|
||||||
|
)
|
||||||
|
except OSError:
|
||||||
|
pending = 0
|
||||||
|
snap["upload_pending"] = str(pending)
|
||||||
|
else:
|
||||||
|
snap["upload_pending"] = "—"
|
||||||
|
|
||||||
|
us = get_upload_stats()
|
||||||
|
if us.total_attempted > 0:
|
||||||
|
snap["upload_session"] = (
|
||||||
|
f"{us.total_ok} ok, {us.total_fail} failed"
|
||||||
|
f" ({us.fail_rate_pct:.1f}%)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
snap["upload_session"] = "no uploads yet"
|
||||||
|
|
||||||
|
if us.last_attempted > 0:
|
||||||
|
snap["upload_last"] = f"{us.last_ok} ok, {us.last_fail} failed"
|
||||||
|
else:
|
||||||
|
snap["upload_last"] = "—"
|
||||||
|
else:
|
||||||
|
snap["upload_state"] = "—"
|
||||||
|
snap["upload_pending"] = "—"
|
||||||
|
snap["upload_session"] = "—"
|
||||||
|
snap["upload_last"] = "—"
|
||||||
|
|
||||||
|
return snap
|
||||||
|
|
||||||
|
async def _check_connection(self) -> str:
|
||||||
|
"""Async check whether ComfyUI is reachable. Returns a plain string."""
|
||||||
|
comfy = getattr(self._bot, "comfy", None)
|
||||||
|
if comfy is None:
|
||||||
|
return "not configured"
|
||||||
|
try:
|
||||||
|
reachable = await asyncio.wait_for(comfy.check_connection(), timeout=4.0)
|
||||||
|
return "reachable" if reachable else "unreachable"
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return "unreachable"
|
||||||
|
except Exception:
|
||||||
|
return "unreachable"
|
||||||
|
|
||||||
|
async def _check_comfy_queue(self) -> dict[str, str]:
|
||||||
|
"""Fetch ComfyUI queue depths. Returns {comfy_pending, comfy_running}."""
|
||||||
|
comfy = getattr(self._bot, "comfy", None)
|
||||||
|
if comfy is None:
|
||||||
|
return {"comfy_pending": "?", "comfy_running": "?"}
|
||||||
|
try:
|
||||||
|
q = await asyncio.wait_for(comfy.get_comfy_queue(), timeout=4.0)
|
||||||
|
if q:
|
||||||
|
return {
|
||||||
|
"comfy_pending": str(len(q.get("queue_pending", []))),
|
||||||
|
"comfy_running": str(len(q.get("queue_running", []))),
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return {"comfy_pending": "?", "comfy_running": "?"}
|
||||||
|
|
||||||
|
async def _check_service_state(self) -> str:
|
||||||
|
"""Return the NSSM service state string for the configured ComfyUI service."""
|
||||||
|
config = getattr(self._bot, "config", None)
|
||||||
|
if config is None:
|
||||||
|
return "unknown"
|
||||||
|
service_name = getattr(config, "comfy_service_name", None)
|
||||||
|
if not service_name:
|
||||||
|
return "unknown"
|
||||||
|
return await get_service_state(service_name)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Change-detection formatting
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _fmt(self, key: str, value: str, *, good: str, bad: str) -> str:
|
||||||
|
"""
|
||||||
|
Format *value* with change-detection highlighting.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
key :
|
||||||
|
Snapshot key used to look up the previous value.
|
||||||
|
value :
|
||||||
|
Current value string.
|
||||||
|
good :
|
||||||
|
The value string that represents a "good" state.
|
||||||
|
bad :
|
||||||
|
The value string that represents a "bad" state.
|
||||||
|
"""
|
||||||
|
prev = self._prev.get(key)
|
||||||
|
changed = prev is not None and prev != value
|
||||||
|
is_good = value == good
|
||||||
|
is_bad = value == bad
|
||||||
|
|
||||||
|
if not changed:
|
||||||
|
if is_good:
|
||||||
|
return f"🟢 {value}"
|
||||||
|
if is_bad:
|
||||||
|
return f"🔴 {value}"
|
||||||
|
return value
|
||||||
|
|
||||||
|
# Value changed
|
||||||
|
if is_bad:
|
||||||
|
return f"⚠️ **{value}**"
|
||||||
|
if is_good:
|
||||||
|
return f"✅ **{value}**"
|
||||||
|
return f"**{value}**"
|
||||||
|
|
||||||
|
def _fmt_service_state(self, value: str) -> str:
|
||||||
|
"""Format NSSM service state with emoji and change-detection highlighting."""
|
||||||
|
prev = self._prev.get("service_state")
|
||||||
|
changed = prev is not None and prev != value
|
||||||
|
emoji = STATUS_EMOJI.get(value, "⚪")
|
||||||
|
if not changed:
|
||||||
|
return f"{emoji} {value}"
|
||||||
|
if value == "SERVICE_RUNNING":
|
||||||
|
return f"✅ **{value}**"
|
||||||
|
if value in ("SERVICE_STOPPED", "error", "timeout"):
|
||||||
|
return f"⚠️ **{value}**"
|
||||||
|
return f"**{value}**"
|
||||||
|
|
||||||
|
def _fmt_queue_size(self, value: str, prev_key: str) -> str:
|
||||||
|
"""Format queue size with ▲/▼ arrows when changed."""
|
||||||
|
prev = self._prev.get(prev_key)
|
||||||
|
if prev is None or prev == value:
|
||||||
|
return value
|
||||||
|
try:
|
||||||
|
arrow = "▲" if int(value) > int(prev) else "▼"
|
||||||
|
except ValueError:
|
||||||
|
arrow = ""
|
||||||
|
return f"**{value}** {arrow}" if arrow else f"**{value}**"
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Message assembly
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _build_message(self, snap: dict[str, str], now: datetime) -> str:
|
||||||
|
"""Assemble the full dashboard message string."""
|
||||||
|
vn_time = now + timedelta(hours=7)
|
||||||
|
timestamp = f"{now.strftime('%H:%M:%S UTC')} ({vn_time.strftime('%H:%M:%S GMT+7')})"
|
||||||
|
|
||||||
|
pending_fmt = self._fmt_queue_size(snap["comfy_pending"], "comfy_pending")
|
||||||
|
running_fmt = self._fmt_queue_size(snap["comfy_running"], "comfy_running")
|
||||||
|
http_fmt = self._fmt("comfy_status", snap["comfy_status"], good="reachable", bad="unreachable")
|
||||||
|
svc_fmt = self._fmt_service_state(snap["service_state"])
|
||||||
|
seed_fmt = self._fmt("last_seed", snap["last_seed"], good="", bad="")
|
||||||
|
prompt_fmt = self._fmt("prompt", snap["prompt"], good="", bad="")
|
||||||
|
neg_fmt = self._fmt("neg_prompt", snap["neg_prompt"], good="", bad="")
|
||||||
|
image_fmt = self._fmt("input_image", snap["input_image"], good="", bad="")
|
||||||
|
pinned_fmt = self._fmt("pinned_seed", snap["pinned_seed"], good="", bad="")
|
||||||
|
|
||||||
|
upload_state_fmt = self._fmt(
|
||||||
|
"upload_state", snap["upload_state"], good="idle", bad=""
|
||||||
|
)
|
||||||
|
upload_pending_fmt = self._fmt(
|
||||||
|
"upload_pending", snap["upload_pending"], good="0", bad=""
|
||||||
|
)
|
||||||
|
upload_session_fmt = self._fmt("upload_session", snap["upload_session"], good="", bad="")
|
||||||
|
upload_last_fmt = self._fmt("upload_last", snap["upload_last"], good="", bad="")
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
f"{self.HEADER_MARKER} **Bot Status Dashboard**",
|
||||||
|
f"-# Last updated: {timestamp}",
|
||||||
|
"",
|
||||||
|
"**Bot**",
|
||||||
|
f" Latency : {snap['latency']}",
|
||||||
|
f" Uptime : {snap['uptime']}",
|
||||||
|
"",
|
||||||
|
f"**ComfyUI** — `{snap['comfy_server']}`",
|
||||||
|
f" Service : {svc_fmt}",
|
||||||
|
f" HTTP : {http_fmt}",
|
||||||
|
f" Queue : {running_fmt} running, {pending_fmt} pending",
|
||||||
|
f" Workflow : {snap['workflow']}",
|
||||||
|
f" Prompt : || {prompt_fmt} ||",
|
||||||
|
f" Neg : || {neg_fmt} ||",
|
||||||
|
f" Image : {image_fmt}",
|
||||||
|
f" Seed : {pinned_fmt}",
|
||||||
|
"",
|
||||||
|
"**Last Generation**",
|
||||||
|
f" Seed : {seed_fmt}",
|
||||||
|
f" Total : {snap['total_gen']}",
|
||||||
|
]
|
||||||
|
|
||||||
|
if snap["upload_configured"] == "enabled":
|
||||||
|
lines += [
|
||||||
|
"",
|
||||||
|
"**Auto-Upload**",
|
||||||
|
f" State : {upload_state_fmt}",
|
||||||
|
f" Pending : {upload_pending_fmt}",
|
||||||
|
f" Session : {upload_session_fmt}",
|
||||||
|
f" Last run : {upload_last_fmt}",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
lines += [
|
||||||
|
"",
|
||||||
|
"**Auto-Upload** — disabled *(set MEDIA_UPLOAD_USER / MEDIA_UPLOAD_PASS to enable)*",
|
||||||
|
]
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Update loop
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _update_loop(self) -> None:
|
||||||
|
"""Background task: collect state, build message, edit in place."""
|
||||||
|
await self._bot.wait_until_ready()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Collect synchronous state
|
||||||
|
snap = self._collect_sync()
|
||||||
|
|
||||||
|
# Async checks run concurrently
|
||||||
|
comfy_status, service_state, queue_stats = await asyncio.gather(
|
||||||
|
self._check_connection(),
|
||||||
|
self._check_service_state(),
|
||||||
|
self._check_comfy_queue(),
|
||||||
|
)
|
||||||
|
snap["comfy_status"] = comfy_status
|
||||||
|
snap["service_state"] = service_state
|
||||||
|
snap.update(queue_stats)
|
||||||
|
|
||||||
|
# Build message text
|
||||||
|
text = self._build_message(snap, now)
|
||||||
|
|
||||||
|
# Ensure we have a message to edit
|
||||||
|
if self._message is None:
|
||||||
|
self._message = await self._get_or_create_message()
|
||||||
|
|
||||||
|
if self._message is not None:
|
||||||
|
try:
|
||||||
|
await self._message.edit(content=text)
|
||||||
|
except discord.NotFound:
|
||||||
|
# Message was deleted — recreate next cycle
|
||||||
|
self._message = None
|
||||||
|
except (discord.HTTPException, OSError) as exc:
|
||||||
|
logger.warning("StatusMonitor: edit failed: %s", exc)
|
||||||
|
|
||||||
|
# Save snapshot for next cycle's change detection
|
||||||
|
self._prev = snap
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.exception("StatusMonitor: unexpected error in update loop")
|
||||||
|
|
||||||
|
await asyncio.sleep(self._interval)
|
||||||
162
token_store.py
Normal file
162
token_store.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""
|
||||||
|
token_store.py
|
||||||
|
==============
|
||||||
|
|
||||||
|
Invite-token CRUD for the web UI.
|
||||||
|
|
||||||
|
Tokens are stored as SHA-256 hashes so the plaintext is never at rest.
|
||||||
|
The plaintext token is only returned once (at creation time).
|
||||||
|
|
||||||
|
File format (invite_tokens.json)::
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "uuid",
|
||||||
|
"label": "alice",
|
||||||
|
"hash": "<sha256-hex>",
|
||||||
|
"admin": false,
|
||||||
|
"created_at": "2024-01-01T00:00:00"
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
# Create a token (CLI)
|
||||||
|
python -c "from token_store import create_token; print(create_token('alice'))"
|
||||||
|
|
||||||
|
# Verify a token (used by auth.py)
|
||||||
|
from token_store import verify_token
|
||||||
|
record = verify_token(plaintext, token_file)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _hash(token: str) -> str:
|
||||||
|
return hashlib.sha256(token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _load(token_file: str) -> list[dict]:
|
||||||
|
path = Path(token_file)
|
||||||
|
if not path.exists():
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to load token file %s: %s", token_file, exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _save(token_file: str, records: list[dict]) -> None:
|
||||||
|
with open(token_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(records, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def create_token(
|
||||||
|
label: str,
|
||||||
|
token_file: str = "invite_tokens.json",
|
||||||
|
*,
|
||||||
|
admin: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Create a new invite token, persist its hash, and return the plaintext.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
label : str
|
||||||
|
Human-readable name for this token (e.g. ``"alice"``).
|
||||||
|
token_file : str
|
||||||
|
Path to the JSON store.
|
||||||
|
admin : bool
|
||||||
|
If True, grants admin privileges.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
The plaintext token (shown once; not stored).
|
||||||
|
"""
|
||||||
|
plaintext = secrets.token_urlsafe(32)
|
||||||
|
record = {
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"label": label,
|
||||||
|
"hash": _hash(plaintext),
|
||||||
|
"admin": admin,
|
||||||
|
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}
|
||||||
|
records = _load(token_file)
|
||||||
|
records.append(record)
|
||||||
|
_save(token_file, records)
|
||||||
|
logger.info("Created token for '%s' (admin=%s)", label, admin)
|
||||||
|
return plaintext
|
||||||
|
|
||||||
|
|
||||||
|
def verify_token(
|
||||||
|
plaintext: str,
|
||||||
|
token_file: str = "invite_tokens.json",
|
||||||
|
) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Verify a plaintext token against the stored hashes.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
plaintext : str
|
||||||
|
The token string provided by the user.
|
||||||
|
token_file : str
|
||||||
|
Path to the JSON store.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Optional[dict]
|
||||||
|
The matching record dict (with ``label`` and ``admin`` fields),
|
||||||
|
or None if no match.
|
||||||
|
"""
|
||||||
|
h = _hash(plaintext)
|
||||||
|
for record in _load(token_file):
|
||||||
|
if record.get("hash") == h:
|
||||||
|
return record
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def list_tokens(token_file: str = "invite_tokens.json") -> list[dict]:
|
||||||
|
"""Return all token records (hashes included, labels safe to show)."""
|
||||||
|
return _load(token_file)
|
||||||
|
|
||||||
|
|
||||||
|
def revoke_token(token_id: str, token_file: str = "invite_tokens.json") -> bool:
|
||||||
|
"""
|
||||||
|
Delete a token by its UUID.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if found and deleted, False if not found.
|
||||||
|
"""
|
||||||
|
records = _load(token_file)
|
||||||
|
new_records = [r for r in records if r.get("id") != token_id]
|
||||||
|
if len(new_records) == len(records):
|
||||||
|
return False
|
||||||
|
_save(token_file, new_records)
|
||||||
|
logger.info("Revoked token %s", token_id)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
label = sys.argv[1] if len(sys.argv) > 1 else "default"
|
||||||
|
is_admin = "--admin" in sys.argv
|
||||||
|
tok = create_token(label, admin=is_admin)
|
||||||
|
print(f"Token for '{label}': {tok}")
|
||||||
147
user_state_registry.py
Normal file
147
user_state_registry.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""
|
||||||
|
user_state_registry.py
|
||||||
|
======================
|
||||||
|
|
||||||
|
Per-user workflow state + template registry.
|
||||||
|
|
||||||
|
Each web-UI user gets their own isolated WorkflowStateManager (persisted to
|
||||||
|
``user_settings/<user_label>.json``) and workflow template.
|
||||||
|
|
||||||
|
New users (no saved file) fall back to the global default workflow template
|
||||||
|
loaded at startup (WORKFLOW_FILE env var or last-used workflow from the
|
||||||
|
global Discord state manager).
|
||||||
|
|
||||||
|
Discord continues to use the shared global state/workflow manager — this
|
||||||
|
registry is only used by the web UI layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from workflow_state import WorkflowStateManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_PROJECT_ROOT = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
|
class UserStateRegistry:
|
||||||
|
"""
|
||||||
|
Per-user isolated workflow state and template store.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
settings_dir : Path
|
||||||
|
Directory where per-user state files are stored. Created automatically.
|
||||||
|
default_workflow : Optional[dict]
|
||||||
|
The global default workflow template. Used when a user has no saved
|
||||||
|
``last_workflow_file``, or the file no longer exists.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
settings_dir: Path,
|
||||||
|
default_workflow: Optional[dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
self._settings_dir = settings_dir
|
||||||
|
self._settings_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._default_workflow: Optional[dict[str, Any]] = default_workflow
|
||||||
|
# user_label → WorkflowStateManager
|
||||||
|
self._managers: Dict[str, WorkflowStateManager] = {}
|
||||||
|
# user_label → workflow template dict (or None)
|
||||||
|
self._templates: Dict[str, Optional[dict[str, Any]]] = {}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Default workflow
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def set_default_workflow(self, template: Optional[dict[str, Any]]) -> None:
|
||||||
|
"""Update the global fallback workflow (called when bot workflow changes)."""
|
||||||
|
self._default_workflow = template
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# User access
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_state_manager(self, user_label: str) -> WorkflowStateManager:
|
||||||
|
"""Return (or lazily create) the WorkflowStateManager for a user."""
|
||||||
|
if user_label not in self._managers:
|
||||||
|
self._init_user(user_label)
|
||||||
|
return self._managers[user_label]
|
||||||
|
|
||||||
|
def get_workflow_template(self, user_label: str) -> Optional[dict[str, Any]]:
|
||||||
|
"""Return the workflow template for a user, or None if not loaded."""
|
||||||
|
if user_label not in self._managers:
|
||||||
|
self._init_user(user_label)
|
||||||
|
return self._templates.get(user_label)
|
||||||
|
|
||||||
|
def set_workflow(
|
||||||
|
self, user_label: str, template: dict[str, Any], filename: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Store a workflow template for a user and persist the filename.
|
||||||
|
|
||||||
|
Clears existing overrides (matches the behaviour of loading a new
|
||||||
|
workflow via the global state manager).
|
||||||
|
"""
|
||||||
|
if user_label not in self._managers:
|
||||||
|
self._init_user(user_label)
|
||||||
|
sm = self._managers[user_label]
|
||||||
|
sm.clear_overrides()
|
||||||
|
sm.set_last_workflow_file(filename)
|
||||||
|
self._templates[user_label] = template
|
||||||
|
logger.debug(
|
||||||
|
"UserStateRegistry: set workflow '%s' for user '%s'", filename, user_label
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internals
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _init_user(self, user_label: str) -> None:
|
||||||
|
"""
|
||||||
|
Initialise state manager and workflow template for a new user.
|
||||||
|
|
||||||
|
1. Create WorkflowStateManager with per-user state file.
|
||||||
|
2. If the file recorded a last_workflow_file, try to load it.
|
||||||
|
3. Fall back to the global default template.
|
||||||
|
"""
|
||||||
|
state_file = str(self._settings_dir / f"{user_label}.json")
|
||||||
|
sm = WorkflowStateManager(state_file=state_file)
|
||||||
|
self._managers[user_label] = sm
|
||||||
|
|
||||||
|
# Try to restore the last workflow this user loaded
|
||||||
|
last_wf = sm.get_last_workflow_file()
|
||||||
|
template: Optional[dict[str, Any]] = None
|
||||||
|
if last_wf:
|
||||||
|
wf_path = _PROJECT_ROOT / "workflows" / last_wf
|
||||||
|
if wf_path.exists():
|
||||||
|
try:
|
||||||
|
with open(wf_path, "r", encoding="utf-8") as f:
|
||||||
|
template = json.load(f)
|
||||||
|
logger.debug(
|
||||||
|
"UserStateRegistry: restored workflow '%s' for user '%s'",
|
||||||
|
last_wf,
|
||||||
|
user_label,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"UserStateRegistry: could not load '%s' for user '%s': %s",
|
||||||
|
last_wf,
|
||||||
|
user_label,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"UserStateRegistry: last workflow '%s' missing for user '%s'; using default",
|
||||||
|
last_wf,
|
||||||
|
user_label,
|
||||||
|
)
|
||||||
|
|
||||||
|
if template is None:
|
||||||
|
template = self._default_workflow
|
||||||
|
self._templates[user_label] = template
|
||||||
309
wan2.2-fast.json
Normal file
309
wan2.2-fast.json
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"inputs": {
|
||||||
|
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
|
||||||
|
"type": "wan",
|
||||||
|
"device": "default"
|
||||||
|
},
|
||||||
|
"class_type": "CLIPLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load CLIP"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"inputs": {
|
||||||
|
"vae_name": "wan_2.1_vae.safetensors"
|
||||||
|
},
|
||||||
|
"class_type": "VAELoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load VAE"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"3": {
|
||||||
|
"inputs": {
|
||||||
|
"unet_name": "wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
|
||||||
|
"weight_dtype": "default"
|
||||||
|
},
|
||||||
|
"class_type": "UNETLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load Diffusion Model"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"inputs": {
|
||||||
|
"unet_name": "wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
|
||||||
|
"weight_dtype": "default"
|
||||||
|
},
|
||||||
|
"class_type": "UNETLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load Diffusion Model"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"inputs": {
|
||||||
|
"shift": 5.000000000000001,
|
||||||
|
"model": [
|
||||||
|
"13",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "ModelSamplingSD3",
|
||||||
|
"_meta": {
|
||||||
|
"title": "ModelSamplingSD3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"inputs": {
|
||||||
|
"fps": 16,
|
||||||
|
"images": [
|
||||||
|
"7",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CreateVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Create Video"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"7": {
|
||||||
|
"inputs": {
|
||||||
|
"samples": [
|
||||||
|
"8",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"2",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "VAE Decode"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"inputs": {
|
||||||
|
"add_noise": "disable",
|
||||||
|
"noise_seed": 0,
|
||||||
|
"steps": 4,
|
||||||
|
"cfg": 1,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "simple",
|
||||||
|
"start_at_step": 2,
|
||||||
|
"end_at_step": 4,
|
||||||
|
"return_with_leftover_noise": "disable",
|
||||||
|
"model": [
|
||||||
|
"5",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"17",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"17",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"11",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSamplerAdvanced",
|
||||||
|
"_meta": {
|
||||||
|
"title": "KSampler (Advanced)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"9": {
|
||||||
|
"inputs": {
|
||||||
|
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors",
|
||||||
|
"strength_model": 1.0000000000000002,
|
||||||
|
"model": [
|
||||||
|
"3",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "LoraLoaderModelOnly",
|
||||||
|
"_meta": {
|
||||||
|
"title": "LoraLoaderModelOnly"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"10": {
|
||||||
|
"inputs": {
|
||||||
|
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors",
|
||||||
|
"strength_model": 1.0000000000000002,
|
||||||
|
"model": [
|
||||||
|
"4",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "LoraLoaderModelOnly",
|
||||||
|
"_meta": {
|
||||||
|
"title": "LoraLoaderModelOnly"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"11": {
|
||||||
|
"inputs": {
|
||||||
|
"add_noise": "enable",
|
||||||
|
"noise_seed": 626287185902791,
|
||||||
|
"steps": 4,
|
||||||
|
"cfg": 1,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "simple",
|
||||||
|
"start_at_step": 0,
|
||||||
|
"end_at_step": 2,
|
||||||
|
"return_with_leftover_noise": "enable",
|
||||||
|
"model": [
|
||||||
|
"12",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"17",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"17",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"17",
|
||||||
|
2
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSamplerAdvanced",
|
||||||
|
"_meta": {
|
||||||
|
"title": "KSampler (Advanced)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"12": {
|
||||||
|
"inputs": {
|
||||||
|
"shift": 5.000000000000001,
|
||||||
|
"model": [
|
||||||
|
"14",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "ModelSamplingSD3",
|
||||||
|
"_meta": {
|
||||||
|
"title": "ModelSamplingSD3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"13": {
|
||||||
|
"inputs": {
|
||||||
|
"lora_name": "NSFW-22-L-e8.safetensors",
|
||||||
|
"strength_model": 1,
|
||||||
|
"model": [
|
||||||
|
"10",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "LoraLoaderModelOnly",
|
||||||
|
"_meta": {
|
||||||
|
"title": "LoraLoaderModelOnly"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"14": {
|
||||||
|
"inputs": {
|
||||||
|
"lora_name": "NSFW-22-H-e8.safetensors",
|
||||||
|
"strength_model": 1,
|
||||||
|
"model": [
|
||||||
|
"9",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "LoraLoaderModelOnly",
|
||||||
|
"_meta": {
|
||||||
|
"title": "LoraLoaderModelOnly"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"15": {
|
||||||
|
"inputs": {
|
||||||
|
"filename_prefix": "ComfyUI",
|
||||||
|
"format": "auto",
|
||||||
|
"codec": "auto",
|
||||||
|
"video": [
|
||||||
|
"6",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "SaveVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Save Video"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走, clothes",
|
||||||
|
"clip": [
|
||||||
|
"1",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Negative Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"17": {
|
||||||
|
"inputs": {
|
||||||
|
"width": 320,
|
||||||
|
"height": 640,
|
||||||
|
"length": 105,
|
||||||
|
"batch_size": 1,
|
||||||
|
"positive": [
|
||||||
|
"18",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"16",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"2",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"start_image": [
|
||||||
|
"19",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "WanImageToVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "WanImageToVideo"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"18": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "A girl undress her clothes, ((completely naked)), showing her ((breasts)), masturbating",
|
||||||
|
"clip": [
|
||||||
|
"1",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Positive Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"19": {
|
||||||
|
"inputs": {
|
||||||
|
"image": "611667883_3429056790577893_8252297829424260857_n.jpg"
|
||||||
|
},
|
||||||
|
"class_type": "LoadImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load Image"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"20": {
|
||||||
|
"inputs": {
|
||||||
|
"anything": [
|
||||||
|
"6",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "easy cleanGpuUsed",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Clean VRAM Used"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
1
web/__init__.py
Normal file
1
web/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# web package
|
||||||
269
web/app.py
Normal file
269
web/app.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""
|
||||||
|
web/app.py
|
||||||
|
==========
|
||||||
|
|
||||||
|
FastAPI application factory.
|
||||||
|
|
||||||
|
The app is created once and shared between the Uvicorn server (started
|
||||||
|
from bot.py via asyncio.gather) and tests.
|
||||||
|
|
||||||
|
Startup tasks:
|
||||||
|
- Background status ticker (broadcasts status_snapshot every 5s to all clients)
|
||||||
|
- Background NSSM poll (broadcasts server_state every 10s to all clients)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from starlette.exceptions import HTTPException as _HTTPException
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request as _Request
|
||||||
|
|
||||||
|
# Windows registry can map .js → text/plain; override to the correct types
|
||||||
|
# before StaticFiles reads them.
|
||||||
|
mimetypes.add_type("application/javascript", ".js")
|
||||||
|
mimetypes.add_type("application/javascript", ".mjs")
|
||||||
|
mimetypes.add_type("text/css", ".css")
|
||||||
|
mimetypes.add_type("application/wasm", ".wasm")
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
|
||||||
|
class _NoCacheHTMLMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Force browsers to revalidate index.html on every request.
|
||||||
|
|
||||||
|
Vite hashes JS/CSS filenames on every build so those assets are
|
||||||
|
naturally cache-busted. index.html itself has a stable name, so
|
||||||
|
without an explicit Cache-Control header mobile browsers apply
|
||||||
|
heuristic caching and keep serving a stale copy after a redeploy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def dispatch(self, request: _Request, call_next):
|
||||||
|
response = await call_next(request)
|
||||||
|
ct = response.headers.get("content-type", "")
|
||||||
|
if "text/html" in ct:
|
||||||
|
response.headers["Cache-Control"] = "no-cache, must-revalidate"
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class _SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Add security headers to every response."""
|
||||||
|
|
||||||
|
_CSP = (
|
||||||
|
"default-src 'self'; "
|
||||||
|
"script-src 'self' 'unsafe-inline'; "
|
||||||
|
"style-src 'self' 'unsafe-inline'; "
|
||||||
|
"img-src 'self' data: blob:; "
|
||||||
|
"connect-src 'self' wss:; "
|
||||||
|
"frame-ancestors 'none';"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def dispatch(self, request: _Request, call_next):
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||||
|
response.headers["X-Frame-Options"] = "DENY"
|
||||||
|
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||||
|
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||||
|
response.headers["Content-Security-Policy"] = self._CSP
|
||||||
|
return response
|
||||||
|
|
||||||
|
class _SPAStaticFiles(StaticFiles):
|
||||||
|
"""StaticFiles with SPA fallback: serve index.html for unknown paths.
|
||||||
|
|
||||||
|
Starlette's html=True only serves index.html for directory requests.
|
||||||
|
This subclass additionally returns index.html for any path that has no
|
||||||
|
matching file, so client-side routes like /generate work on refresh.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get_response(self, path: str, scope):
|
||||||
|
try:
|
||||||
|
return await super().get_response(path, scope)
|
||||||
|
except _HTTPException as ex:
|
||||||
|
if ex.status_code == 404:
|
||||||
|
return await super().get_response("index.html", scope)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
from web.ws_bus import get_bus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||||
|
_WEB_STATIC = _PROJECT_ROOT / "web-static"
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
"""Create and configure the FastAPI application."""
|
||||||
|
app = FastAPI(
|
||||||
|
title="ComfyUI Bot Web UI",
|
||||||
|
version="1.0.0",
|
||||||
|
docs_url=None,
|
||||||
|
redoc_url=None,
|
||||||
|
openapi_url=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# CORS — only allow explicitly configured origins; empty = no cross-origin
|
||||||
|
_cors_origins = [o.strip() for o in os.getenv("CORS_ORIGINS", "").split(",") if o.strip()]
|
||||||
|
if _cors_origins:
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=_cors_origins,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Security headers on every response
|
||||||
|
app.add_middleware(_SecurityHeadersMiddleware)
|
||||||
|
|
||||||
|
# Prevent browsers from caching index.html across deploys
|
||||||
|
app.add_middleware(_NoCacheHTMLMiddleware)
|
||||||
|
|
||||||
|
# Register API routers
|
||||||
|
from web.routers.auth_router import router as auth_router
|
||||||
|
from web.routers.admin_router import router as admin_router
|
||||||
|
from web.routers.status_router import router as status_router
|
||||||
|
from web.routers.state_router import router as state_router
|
||||||
|
from web.routers.generate_router import router as generate_router
|
||||||
|
from web.routers.inputs_router import router as inputs_router
|
||||||
|
from web.routers.presets_router import router as presets_router
|
||||||
|
from web.routers.server_router import router as server_router
|
||||||
|
from web.routers.history_router import router as history_router
|
||||||
|
from web.routers.share_router import router as share_router
|
||||||
|
from web.routers.workflow_router import router as workflow_router
|
||||||
|
from web.routers.ws_router import router as ws_router
|
||||||
|
|
||||||
|
app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
|
||||||
|
app.include_router(admin_router, prefix="/api/admin", tags=["admin"])
|
||||||
|
app.include_router(status_router, prefix="/api", tags=["status"])
|
||||||
|
app.include_router(state_router, prefix="/api", tags=["state"])
|
||||||
|
app.include_router(generate_router, prefix="/api", tags=["generate"])
|
||||||
|
app.include_router(inputs_router, prefix="/api/inputs", tags=["inputs"])
|
||||||
|
app.include_router(presets_router, prefix="/api/presets", tags=["presets"])
|
||||||
|
app.include_router(server_router, prefix="/api", tags=["server"])
|
||||||
|
app.include_router(history_router, prefix="/api/history", tags=["history"])
|
||||||
|
app.include_router(share_router, prefix="/api/share", tags=["share"])
|
||||||
|
app.include_router(workflow_router, prefix="/api/workflow", tags=["workflow"])
|
||||||
|
app.include_router(ws_router, tags=["ws"])
|
||||||
|
|
||||||
|
# Serve frontend static files (if built)
|
||||||
|
if _WEB_STATIC.exists() and any(_WEB_STATIC.iterdir()):
|
||||||
|
app.mount("/", _SPAStaticFiles(directory=str(_WEB_STATIC), html=True), name="static")
|
||||||
|
logger.info("Serving frontend from %s", _WEB_STATIC)
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def _startup():
|
||||||
|
asyncio.create_task(_status_ticker())
|
||||||
|
asyncio.create_task(_server_state_poller())
|
||||||
|
logger.info("Web background tasks started")
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Background tasks
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _status_ticker() -> None:
|
||||||
|
"""Broadcast status_snapshot to all clients every 5 seconds."""
|
||||||
|
from web.deps import get_bot, get_comfy, get_config
|
||||||
|
bus = get_bus()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
try:
|
||||||
|
bot = get_bot()
|
||||||
|
comfy = get_comfy()
|
||||||
|
config = get_config()
|
||||||
|
|
||||||
|
snapshot: dict = {}
|
||||||
|
|
||||||
|
if bot is not None:
|
||||||
|
lat = bot.latency
|
||||||
|
lat_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0
|
||||||
|
import datetime as _dt
|
||||||
|
start = getattr(bot, "start_time", None)
|
||||||
|
uptime = ""
|
||||||
|
if start:
|
||||||
|
delta = _dt.datetime.now(_dt.timezone.utc) - start
|
||||||
|
total = int(delta.total_seconds())
|
||||||
|
h, rem = divmod(total, 3600)
|
||||||
|
m, s = divmod(rem, 60)
|
||||||
|
uptime = f"{h}h {m}m {s}s" if h else (f"{m}m {s}s" if m else f"{s}s")
|
||||||
|
snapshot["bot"] = {"latency_ms": lat_ms, "uptime": uptime}
|
||||||
|
|
||||||
|
if comfy is not None:
|
||||||
|
q = await comfy.get_comfy_queue()
|
||||||
|
pending = len(q.get("queue_pending", [])) if q else 0
|
||||||
|
running = len(q.get("queue_running", [])) if q else 0
|
||||||
|
wm = getattr(comfy, "workflow_manager", None)
|
||||||
|
wf_loaded = wm is not None and wm.get_workflow_template() is not None
|
||||||
|
snapshot["comfy"] = {
|
||||||
|
"server": comfy.server_address,
|
||||||
|
"queue_pending": pending,
|
||||||
|
"queue_running": running,
|
||||||
|
"workflow_loaded": wf_loaded,
|
||||||
|
"last_seed": comfy.last_seed,
|
||||||
|
"total_generated": comfy.total_generated,
|
||||||
|
}
|
||||||
|
|
||||||
|
if config is not None:
|
||||||
|
from media_uploader import get_stats as get_upload_stats, is_running as upload_running
|
||||||
|
try:
|
||||||
|
us = get_upload_stats()
|
||||||
|
snapshot["upload"] = {
|
||||||
|
"configured": bool(config.media_upload_user),
|
||||||
|
"running": upload_running(),
|
||||||
|
"total_ok": us.total_ok,
|
||||||
|
"total_fail": us.total_fail,
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from web.deps import get_user_registry
|
||||||
|
registry = get_user_registry()
|
||||||
|
connected = bus.connected_users
|
||||||
|
if connected and registry:
|
||||||
|
for ul in connected:
|
||||||
|
user_overrides = registry.get_state_manager(ul).get_overrides()
|
||||||
|
await bus.broadcast_to_user(ul, "status_snapshot", {**snapshot, "overrides": user_overrides})
|
||||||
|
else:
|
||||||
|
await bus.broadcast("status_snapshot", snapshot)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Status ticker error: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _server_state_poller() -> None:
|
||||||
|
"""Poll NSSM service state and broadcast server_state every 10 seconds."""
|
||||||
|
from web.deps import get_config
|
||||||
|
bus = get_bus()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
try:
|
||||||
|
config = get_config()
|
||||||
|
if config is None:
|
||||||
|
continue
|
||||||
|
from commands.server import get_service_state
|
||||||
|
from web.deps import get_comfy
|
||||||
|
|
||||||
|
async def _false():
|
||||||
|
return False
|
||||||
|
|
||||||
|
comfy = get_comfy()
|
||||||
|
service_state, http_reachable = await asyncio.gather(
|
||||||
|
get_service_state(config.comfy_service_name),
|
||||||
|
comfy.check_connection() if comfy else _false(),
|
||||||
|
)
|
||||||
|
await bus.broadcast("server_state", {
|
||||||
|
"state": service_state,
|
||||||
|
"http_reachable": http_reachable,
|
||||||
|
})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Server state poller error: %s", exc)
|
||||||
119
web/auth.py
Normal file
119
web/auth.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
"""
|
||||||
|
web/auth.py
|
||||||
|
===========
|
||||||
|
|
||||||
|
JWT authentication for the web UI.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
- POST /api/auth/login {token} → verify invite token → issue JWT in httpOnly cookie
|
||||||
|
- All /api/* require valid JWT via require_auth dependency
|
||||||
|
- POST /api/admin/login {password} → issue admin JWT (admin: true claim)
|
||||||
|
- WS /ws?token=<jwt> → authenticate via query param
|
||||||
|
|
||||||
|
JWT claims: {"sub": "<label>", "admin": bool, "exp": ...}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import Cookie, Depends, HTTPException, status
|
||||||
|
from fastapi.security import HTTPBearer
|
||||||
|
|
||||||
|
try:
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
except ImportError:
|
||||||
|
jwt = None # type: ignore
|
||||||
|
JWTError = Exception # type: ignore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ALGORITHM = "HS256"
|
||||||
|
_COOKIE_NAME = "ttb_session"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_secret() -> str:
|
||||||
|
from web.deps import get_config
|
||||||
|
cfg = get_config()
|
||||||
|
if cfg and cfg.web_secret_key:
|
||||||
|
return cfg.web_secret_key
|
||||||
|
raise RuntimeError(
|
||||||
|
"WEB_SECRET_KEY must be set in the environment — "
|
||||||
|
"refusing to run with an insecure default."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_jwt(label: str, *, admin: bool = False, expire_hours: int = 8) -> str:
|
||||||
|
"""Create a signed JWT for the given user label."""
|
||||||
|
if jwt is None:
|
||||||
|
raise RuntimeError("python-jose is not installed (pip install python-jose[cryptography])")
|
||||||
|
expire = datetime.now(timezone.utc) + timedelta(hours=expire_hours)
|
||||||
|
payload = {"sub": label, "admin": admin, "exp": expire}
|
||||||
|
return jwt.encode(payload, _get_secret(), algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_jwt(token: str) -> Optional[dict]:
|
||||||
|
"""Decode and verify a JWT. Returns the payload or None on failure."""
|
||||||
|
if jwt is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return jwt.decode(token, _get_secret(), algorithms=[ALGORITHM])
|
||||||
|
except JWTError as exc:
|
||||||
|
logger.debug("JWT decode failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def verify_ws_token(token: str) -> Optional[dict]:
|
||||||
|
"""Verify a JWT passed as a WebSocket query parameter."""
|
||||||
|
return decode_jwt(token)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# FastAPI dependencies
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def require_auth(ttb_session: Optional[str] = Cookie(default=None)) -> dict:
|
||||||
|
"""
|
||||||
|
FastAPI dependency that requires a valid JWT cookie.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
The decoded JWT payload (``sub``, ``admin`` fields).
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
HTTPException 401
|
||||||
|
If the cookie is absent or the token is invalid/expired.
|
||||||
|
"""
|
||||||
|
if not ttb_session:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Not authenticated",
|
||||||
|
)
|
||||||
|
payload = decode_jwt(ttb_session)
|
||||||
|
if payload is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or expired token",
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def require_admin(user: dict = Depends(require_auth)) -> dict:
|
||||||
|
"""
|
||||||
|
FastAPI dependency that requires an admin JWT.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
HTTPException 403
|
||||||
|
If the token is valid but not admin.
|
||||||
|
"""
|
||||||
|
if not user.get("admin"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Admin access required",
|
||||||
|
)
|
||||||
|
return user
|
||||||
72
web/deps.py
Normal file
72
web/deps.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""
|
||||||
|
web/deps.py
|
||||||
|
===========
|
||||||
|
|
||||||
|
Shared bot reference for FastAPI dependency injection.
|
||||||
|
|
||||||
|
``set_bot()`` is called once from ``bot.py`` before starting Uvicorn.
|
||||||
|
FastAPI route handlers use ``get_bot()``, ``get_comfy()``, etc. as
|
||||||
|
Depends() callables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
_bot = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_bot(bot) -> None:
|
||||||
|
"""Store the discord.py bot instance for DI access."""
|
||||||
|
global _bot
|
||||||
|
_bot = bot
|
||||||
|
|
||||||
|
|
||||||
|
def get_bot():
|
||||||
|
"""FastAPI dependency: return the bot instance."""
|
||||||
|
return _bot
|
||||||
|
|
||||||
|
|
||||||
|
def get_comfy():
|
||||||
|
"""FastAPI dependency: return the ComfyClient."""
|
||||||
|
if _bot is None:
|
||||||
|
return None
|
||||||
|
return getattr(_bot, "comfy", None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
"""FastAPI dependency: return the BotConfig."""
|
||||||
|
if _bot is None:
|
||||||
|
return None
|
||||||
|
return getattr(_bot, "config", None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_state_manager():
|
||||||
|
"""FastAPI dependency: return the WorkflowStateManager."""
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy is None:
|
||||||
|
return None
|
||||||
|
return getattr(comfy, "state_manager", None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_workflow_manager():
|
||||||
|
"""FastAPI dependency: return the WorkflowManager."""
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy is None:
|
||||||
|
return None
|
||||||
|
return getattr(comfy, "workflow_manager", None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_inspector():
|
||||||
|
"""FastAPI dependency: return the WorkflowInspector."""
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy is None:
|
||||||
|
return None
|
||||||
|
return getattr(comfy, "inspector", None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_registry():
|
||||||
|
"""FastAPI dependency: return the UserStateRegistry."""
|
||||||
|
if _bot is None:
|
||||||
|
return None
|
||||||
|
return getattr(_bot, "user_registry", None)
|
||||||
146
web/login_guard.py
Normal file
146
web/login_guard.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""
|
||||||
|
web/login_guard.py
|
||||||
|
==================
|
||||||
|
|
||||||
|
IP-based brute-force protection for login endpoints.
|
||||||
|
|
||||||
|
Tracks failed login attempts per IP in a rolling time window and issues a
|
||||||
|
temporary ban when the threshold is exceeded. Uses only stdlib — no new
|
||||||
|
pip packages required.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
from web.login_guard import get_guard, get_real_ip
|
||||||
|
|
||||||
|
@router.post("/login")
|
||||||
|
async def login(request: Request, body: LoginRequest, response: Response):
|
||||||
|
ip = get_real_ip(request)
|
||||||
|
get_guard().check(ip) # raises 429 if locked out
|
||||||
|
...
|
||||||
|
if failure:
|
||||||
|
get_guard().record_failure(ip)
|
||||||
|
raise HTTPException(401, ...)
|
||||||
|
get_guard().record_success(ip)
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request, status
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_real_ip(request: Request) -> str:
|
||||||
|
"""Return the real client IP, honouring Cloudflare and common proxy headers.
|
||||||
|
|
||||||
|
Priority:
|
||||||
|
1. ``CF-Connecting-IP`` (set by Cloudflare)
|
||||||
|
2. ``X-Real-IP`` (set by nginx/traefik)
|
||||||
|
3. ``request.client.host`` (direct connection fallback)
|
||||||
|
"""
|
||||||
|
cf_ip = request.headers.get("CF-Connecting-IP", "").strip()
|
||||||
|
if cf_ip:
|
||||||
|
return cf_ip
|
||||||
|
real_ip = request.headers.get("X-Real-IP", "").strip()
|
||||||
|
if real_ip:
|
||||||
|
return real_ip
|
||||||
|
return request.client.host if request.client else "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class BruteForceGuard:
|
||||||
|
"""Rolling-window failure counter with automatic IP bans.
|
||||||
|
|
||||||
|
All state is in-process memory. A restart clears all bans and counters,
|
||||||
|
which is acceptable — a brief restart already provides a natural backoff
|
||||||
|
for a legitimate attacker.
|
||||||
|
"""
|
||||||
|
|
||||||
|
WINDOW_SECS = 600 # rolling window: 10 minutes
|
||||||
|
MAX_FAILURES = 10 # max failures before ban
|
||||||
|
BAN_SECS = 3600 # ban duration: 1 hour
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# ip → list of failure timestamps (epoch floats)
|
||||||
|
self._failures: Dict[str, List[float]] = defaultdict(list)
|
||||||
|
# ip → ban expiry timestamp
|
||||||
|
self._ban_until: Dict[str, float] = {}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def check(self, ip: str) -> None:
|
||||||
|
"""Raise HTTP 429 if the IP is banned or has exceeded the failure threshold.
|
||||||
|
|
||||||
|
Call this *before* doing any credential work so the lockout is
|
||||||
|
evaluated even when the request body is malformed.
|
||||||
|
"""
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
# Active ban?
|
||||||
|
ban_expiry = self._ban_until.get(ip, 0)
|
||||||
|
if ban_expiry > now:
|
||||||
|
logger.warning("login_guard: blocked request from banned ip=%s (ban expires in %.0fs)", ip, ban_expiry - now)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
|
detail="Too many failed attempts. Try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Failure count within rolling window
|
||||||
|
cutoff = now - self.WINDOW_SECS
|
||||||
|
recent = [t for t in self._failures[ip] if t > cutoff]
|
||||||
|
self._failures[ip] = recent # prune stale entries while we're here
|
||||||
|
if len(recent) >= self.MAX_FAILURES:
|
||||||
|
# Threshold just reached — apply ban now
|
||||||
|
self._ban_until[ip] = now + self.BAN_SECS
|
||||||
|
logger.warning(
|
||||||
|
"login_guard: threshold reached, banning ip=%s for %ds",
|
||||||
|
ip, self.BAN_SECS,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
|
detail="Too many failed attempts. Try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def record_failure(self, ip: str) -> None:
|
||||||
|
"""Record a failed login attempt for the given IP."""
|
||||||
|
now = time.time()
|
||||||
|
cutoff = now - self.WINDOW_SECS
|
||||||
|
recent = [t for t in self._failures[ip] if t > cutoff]
|
||||||
|
recent.append(now)
|
||||||
|
self._failures[ip] = recent
|
||||||
|
count = len(recent)
|
||||||
|
logger.warning("login_guard: failure #%d from ip=%s", count, ip)
|
||||||
|
|
||||||
|
if count >= self.MAX_FAILURES:
|
||||||
|
self._ban_until[ip] = now + self.BAN_SECS
|
||||||
|
logger.warning(
|
||||||
|
"login_guard: threshold reached, banning ip=%s for %ds",
|
||||||
|
ip, self.BAN_SECS,
|
||||||
|
)
|
||||||
|
|
||||||
|
def record_success(self, ip: str) -> None:
|
||||||
|
"""Clear failure history and any active ban for the given IP."""
|
||||||
|
self._failures.pop(ip, None)
|
||||||
|
self._ban_until.pop(ip, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module-level singleton
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_guard: BruteForceGuard | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_guard() -> BruteForceGuard:
|
||||||
|
"""Return the shared BruteForceGuard singleton (created on first call)."""
|
||||||
|
global _guard
|
||||||
|
if _guard is None:
|
||||||
|
_guard = BruteForceGuard()
|
||||||
|
return _guard
|
||||||
1
web/routers/__init__.py
Normal file
1
web/routers/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# web.routers package
|
||||||
88
web/routers/admin_router.py
Normal file
88
web/routers/admin_router.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""POST /api/admin/login; GET/POST/DELETE /api/admin/tokens"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hmac
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from web.auth import create_jwt, require_admin
|
||||||
|
from web.deps import get_config
|
||||||
|
from web.login_guard import get_guard, get_real_ip
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
_COOKIE = "ttb_session"
|
||||||
|
audit = logging.getLogger("audit")
|
||||||
|
|
||||||
|
|
||||||
|
class AdminLoginRequest(BaseModel):
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class CreateTokenRequest(BaseModel):
|
||||||
|
label: str
|
||||||
|
admin: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login")
|
||||||
|
async def admin_login(body: AdminLoginRequest, request: Request, response: Response):
|
||||||
|
"""Admin password login → admin JWT cookie."""
|
||||||
|
config = get_config()
|
||||||
|
expected_pw = config.admin_password if config else None
|
||||||
|
|
||||||
|
ip = get_real_ip(request)
|
||||||
|
get_guard().check(ip)
|
||||||
|
|
||||||
|
# Constant-time comparison to prevent timing attacks
|
||||||
|
if not expected_pw or not hmac.compare_digest(body.password, expected_pw):
|
||||||
|
get_guard().record_failure(ip)
|
||||||
|
audit.info("admin.login ip=%s success=False", ip)
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Wrong password")
|
||||||
|
|
||||||
|
get_guard().record_success(ip)
|
||||||
|
audit.info("admin.login ip=%s success=True", ip)
|
||||||
|
|
||||||
|
expire_hours = config.web_jwt_expire_hours if config else 8
|
||||||
|
jwt_token = create_jwt("admin", admin=True, expire_hours=expire_hours)
|
||||||
|
response.set_cookie(
|
||||||
|
_COOKIE, jwt_token,
|
||||||
|
httponly=True, secure=True, samesite="strict",
|
||||||
|
max_age=expire_hours * 3600,
|
||||||
|
)
|
||||||
|
return {"label": "admin", "admin": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/tokens")
|
||||||
|
async def list_tokens(_: dict = Depends(require_admin)):
|
||||||
|
"""List all invite tokens (hashes shown, labels safe)."""
|
||||||
|
from token_store import list_tokens as _list
|
||||||
|
config = get_config()
|
||||||
|
token_file = config.web_token_file if config else "invite_tokens.json"
|
||||||
|
records = _list(token_file)
|
||||||
|
# Don't return hashes to the UI
|
||||||
|
return [{"id": r["id"], "label": r["label"], "admin": r.get("admin", False),
|
||||||
|
"created_at": r.get("created_at")} for r in records]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/tokens")
|
||||||
|
async def create_token(body: CreateTokenRequest, _: dict = Depends(require_admin)):
|
||||||
|
"""Create a new invite token. Returns the plaintext token (shown once)."""
|
||||||
|
from token_store import create_token as _create
|
||||||
|
config = get_config()
|
||||||
|
token_file = config.web_token_file if config else "invite_tokens.json"
|
||||||
|
plaintext = _create(body.label, token_file, admin=body.admin)
|
||||||
|
return {"token": plaintext, "label": body.label, "admin": body.admin}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/tokens/{token_id}")
|
||||||
|
async def revoke_token(token_id: str, _: dict = Depends(require_admin)):
|
||||||
|
"""Revoke an invite token by ID."""
|
||||||
|
from token_store import revoke_token as _revoke
|
||||||
|
config = get_config()
|
||||||
|
token_file = config.web_token_file if config else "invite_tokens.json"
|
||||||
|
ok = _revoke(token_id, token_file)
|
||||||
|
if not ok:
|
||||||
|
raise HTTPException(status_code=404, detail="Token not found")
|
||||||
|
return {"ok": True}
|
||||||
64
web/routers/auth_router.py
Normal file
64
web/routers/auth_router.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""POST /api/auth/login|logout; GET /api/auth/me"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from web.auth import create_jwt, require_auth
|
||||||
|
from web.deps import get_config
|
||||||
|
from web.login_guard import get_guard, get_real_ip
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
_COOKIE = "ttb_session"
|
||||||
|
audit = logging.getLogger("audit")
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
token: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login")
|
||||||
|
async def login(body: LoginRequest, request: Request, response: Response):
|
||||||
|
"""Exchange an invite token for a JWT session cookie."""
|
||||||
|
from token_store import verify_token
|
||||||
|
config = get_config()
|
||||||
|
token_file = config.web_token_file if config else "invite_tokens.json"
|
||||||
|
expire_hours = config.web_jwt_expire_hours if config else 8
|
||||||
|
|
||||||
|
ip = get_real_ip(request)
|
||||||
|
get_guard().check(ip)
|
||||||
|
|
||||||
|
record = verify_token(body.token, token_file)
|
||||||
|
if record is None:
|
||||||
|
get_guard().record_failure(ip)
|
||||||
|
audit.info("auth.login ip=%s success=False", ip)
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||||
|
|
||||||
|
label: str = record["label"]
|
||||||
|
admin: bool = record.get("admin", False)
|
||||||
|
get_guard().record_success(ip)
|
||||||
|
audit.info("auth.login ip=%s success=True label=%s", ip, label)
|
||||||
|
|
||||||
|
jwt_token = create_jwt(label, admin=admin, expire_hours=expire_hours)
|
||||||
|
response.set_cookie(
|
||||||
|
_COOKIE, jwt_token,
|
||||||
|
httponly=True, secure=True, samesite="strict",
|
||||||
|
max_age=expire_hours * 3600,
|
||||||
|
)
|
||||||
|
return {"label": label, "admin": admin}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout")
|
||||||
|
async def logout(response: Response):
|
||||||
|
"""Clear the session cookie."""
|
||||||
|
response.delete_cookie(_COOKIE)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me")
|
||||||
|
async def me(user: dict = Depends(require_auth)):
|
||||||
|
"""Return current user info."""
|
||||||
|
return {"label": user["sub"], "admin": user.get("admin", False)}
|
||||||
255
web/routers/generate_router.py
Normal file
255
web/routers/generate_router.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
"""POST /api/generate and /api/workflow-gen"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from web.auth import require_auth
|
||||||
|
from web.deps import get_comfy, get_config, get_user_registry
|
||||||
|
from web.ws_bus import get_bus
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateRequest(BaseModel):
|
||||||
|
prompt: str
|
||||||
|
negative_prompt: Optional[str] = None
|
||||||
|
overrides: Optional[Dict[str, Any]] = None # extra per-request overrides
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowGenRequest(BaseModel):
|
||||||
|
count: int = 1
|
||||||
|
overrides: Optional[Dict[str, Any]] = None # per-request overrides (merged with state)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/generate")
|
||||||
|
async def generate(body: GenerateRequest, user: dict = Depends(require_auth)):
|
||||||
|
"""Submit a prompt-based generation to ComfyUI."""
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy is None:
|
||||||
|
raise HTTPException(503, "ComfyUI client not available")
|
||||||
|
|
||||||
|
user_label: str = user["sub"]
|
||||||
|
bus = get_bus()
|
||||||
|
registry = get_user_registry()
|
||||||
|
|
||||||
|
# Temporary seed override from request
|
||||||
|
if body.overrides and "seed" in body.overrides:
|
||||||
|
seed_override = body.overrides["seed"]
|
||||||
|
elif registry:
|
||||||
|
seed_override = registry.get_state_manager(user_label).get_seed()
|
||||||
|
else:
|
||||||
|
seed_override = comfy.state_manager.get_seed()
|
||||||
|
|
||||||
|
overrides_for_gen = {"prompt": body.prompt}
|
||||||
|
if body.negative_prompt:
|
||||||
|
overrides_for_gen["negative_prompt"] = body.negative_prompt
|
||||||
|
if seed_override is not None:
|
||||||
|
overrides_for_gen["seed"] = seed_override
|
||||||
|
|
||||||
|
# Also apply any extra per-request overrides
|
||||||
|
if body.overrides:
|
||||||
|
overrides_for_gen.update(body.overrides)
|
||||||
|
|
||||||
|
# Get queue position estimate
|
||||||
|
depth = await comfy.get_queue_depth()
|
||||||
|
|
||||||
|
# Start generation as background task so we can return the prompt_id immediately
|
||||||
|
prompt_id_holder: list = []
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
# Use the user's own workflow template
|
||||||
|
if registry:
|
||||||
|
template = registry.get_workflow_template(user_label)
|
||||||
|
else:
|
||||||
|
template = comfy.workflow_manager.get_workflow_template()
|
||||||
|
if not template:
|
||||||
|
await bus.broadcast_to_user(user_label, "generation_error", {
|
||||||
|
"prompt_id": None, "error": "No workflow template loaded"
|
||||||
|
})
|
||||||
|
return
|
||||||
|
import uuid
|
||||||
|
pid = str(uuid.uuid4())
|
||||||
|
prompt_id_holder.append(pid)
|
||||||
|
|
||||||
|
def on_progress(node, pid_):
|
||||||
|
asyncio.create_task(bus.broadcast("node_executing", {
|
||||||
|
"node": node, "prompt_id": pid_
|
||||||
|
}))
|
||||||
|
|
||||||
|
workflow, applied = comfy.inspector.inject_overrides(template, overrides_for_gen)
|
||||||
|
seed_used = applied.get("seed")
|
||||||
|
comfy.last_seed = seed_used
|
||||||
|
|
||||||
|
try:
|
||||||
|
images, videos = await comfy._general_generate(workflow, pid, on_progress)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Generation error for prompt %s", pid)
|
||||||
|
await bus.broadcast_to_user(user_label, "generation_error", {
|
||||||
|
"prompt_id": pid, "error": str(exc)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
comfy.last_prompt_id = pid
|
||||||
|
comfy.total_generated += 1
|
||||||
|
|
||||||
|
# Persist to DB before flush_pending deletes local files
|
||||||
|
config = get_config()
|
||||||
|
try:
|
||||||
|
from generation_db import record_generation, record_file
|
||||||
|
gen_id = record_generation(pid, "web", user_label, overrides_for_gen, seed_used)
|
||||||
|
for i, img_data in enumerate(images):
|
||||||
|
record_file(gen_id, f"image_{i:04d}.png", img_data)
|
||||||
|
if config and videos:
|
||||||
|
for vid in videos:
|
||||||
|
vsub = vid.get("video_subfolder", "")
|
||||||
|
vname = vid.get("video_name", "")
|
||||||
|
vpath = (
|
||||||
|
Path(config.comfy_output_path) / vsub / vname
|
||||||
|
if vsub
|
||||||
|
else Path(config.comfy_output_path) / vname
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
record_file(gen_id, vname, vpath.read_bytes())
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to record generation to DB: %s", exc)
|
||||||
|
|
||||||
|
# Flush auto-upload
|
||||||
|
if config:
|
||||||
|
from media_uploader import flush_pending
|
||||||
|
asyncio.create_task(flush_pending(
|
||||||
|
Path(config.comfy_output_path),
|
||||||
|
config.media_upload_user,
|
||||||
|
config.media_upload_pass,
|
||||||
|
))
|
||||||
|
|
||||||
|
await bus.broadcast("queue_update", {
|
||||||
|
"prompt_id": pid,
|
||||||
|
"status": "complete",
|
||||||
|
})
|
||||||
|
await bus.broadcast_to_user(user_label, "generation_complete", {
|
||||||
|
"prompt_id": pid,
|
||||||
|
"seed": seed_used,
|
||||||
|
"image_count": len(images),
|
||||||
|
"video_count": len(videos),
|
||||||
|
})
|
||||||
|
|
||||||
|
asyncio.create_task(_run())
|
||||||
|
|
||||||
|
return {
|
||||||
|
"queued": True,
|
||||||
|
"queue_position": depth + 1,
|
||||||
|
"message": "Generation submitted to ComfyUI",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/workflow-gen")
|
||||||
|
async def workflow_gen(body: WorkflowGenRequest, user: dict = Depends(require_auth)):
|
||||||
|
"""Submit workflow-based generation(s) to ComfyUI."""
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy is None:
|
||||||
|
raise HTTPException(503, "ComfyUI client not available")
|
||||||
|
|
||||||
|
user_label: str = user["sub"]
|
||||||
|
bus = get_bus()
|
||||||
|
registry = get_user_registry()
|
||||||
|
count = max(1, min(body.count, 20)) # cap at 20
|
||||||
|
|
||||||
|
async def _run_one():
|
||||||
|
# Use the user's own state and template
|
||||||
|
if registry:
|
||||||
|
user_sm = registry.get_state_manager(user_label)
|
||||||
|
user_template = registry.get_workflow_template(user_label)
|
||||||
|
else:
|
||||||
|
user_sm = comfy.state_manager
|
||||||
|
user_template = comfy.workflow_manager.get_workflow_template()
|
||||||
|
|
||||||
|
if not user_template:
|
||||||
|
await bus.broadcast_to_user(user_label, "generation_error", {
|
||||||
|
"prompt_id": None, "error": "No workflow template loaded"
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
overrides = user_sm.get_overrides()
|
||||||
|
if body.overrides:
|
||||||
|
overrides = {**overrides, **body.overrides}
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
pid = str(uuid.uuid4())
|
||||||
|
|
||||||
|
def on_progress(node, pid_):
|
||||||
|
asyncio.create_task(bus.broadcast("node_executing", {
|
||||||
|
"node": node, "prompt_id": pid_
|
||||||
|
}))
|
||||||
|
|
||||||
|
workflow, applied = comfy.inspector.inject_overrides(user_template, overrides)
|
||||||
|
seed_used = applied.get("seed")
|
||||||
|
comfy.last_seed = seed_used
|
||||||
|
|
||||||
|
try:
|
||||||
|
images, videos = await comfy._general_generate(workflow, pid, on_progress)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Workflow gen error")
|
||||||
|
await bus.broadcast_to_user(user_label, "generation_error", {
|
||||||
|
"prompt_id": None, "error": str(exc)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
comfy.last_prompt_id = pid
|
||||||
|
comfy.total_generated += 1
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
try:
|
||||||
|
from generation_db import record_generation, record_file
|
||||||
|
gen_id = record_generation(pid, "web", user_label, overrides, seed_used)
|
||||||
|
for i, img_data in enumerate(images):
|
||||||
|
record_file(gen_id, f"image_{i:04d}.png", img_data)
|
||||||
|
if config and videos:
|
||||||
|
for vid in videos:
|
||||||
|
vsub = vid.get("video_subfolder", "")
|
||||||
|
vname = vid.get("video_name", "")
|
||||||
|
vpath = (
|
||||||
|
Path(config.comfy_output_path) / vsub / vname
|
||||||
|
if vsub
|
||||||
|
else Path(config.comfy_output_path) / vname
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
record_file(gen_id, vname, vpath.read_bytes())
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to record generation to DB: %s", exc)
|
||||||
|
|
||||||
|
if config:
|
||||||
|
from media_uploader import flush_pending
|
||||||
|
asyncio.create_task(flush_pending(
|
||||||
|
Path(config.comfy_output_path),
|
||||||
|
config.media_upload_user,
|
||||||
|
config.media_upload_pass,
|
||||||
|
))
|
||||||
|
|
||||||
|
await bus.broadcast("queue_update", {"prompt_id": pid, "status": "complete"})
|
||||||
|
await bus.broadcast_to_user(user_label, "generation_complete", {
|
||||||
|
"prompt_id": pid,
|
||||||
|
"seed": seed_used,
|
||||||
|
"image_count": len(images),
|
||||||
|
"video_count": len(videos),
|
||||||
|
})
|
||||||
|
|
||||||
|
depth = await comfy.get_queue_depth()
|
||||||
|
for _ in range(count):
|
||||||
|
asyncio.create_task(_run_one())
|
||||||
|
|
||||||
|
return {
|
||||||
|
"queued": True,
|
||||||
|
"count": count,
|
||||||
|
"queue_position": depth + 1,
|
||||||
|
}
|
||||||
143
web/routers/history_router.py
Normal file
143
web/routers/history_router.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""GET /api/history; GET /api/history/{prompt_id}/images; GET /api/history/{prompt_id}/file/{filename};
|
||||||
|
POST /api/history/{prompt_id}/share; DELETE /api/history/{prompt_id}/share"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from web.auth import require_auth
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_owns(prompt_id: str, user: dict) -> None:
|
||||||
|
"""Raise 404 if the generation doesn't exist or doesn't belong to the user.
|
||||||
|
|
||||||
|
Returning the same 404 for both cases prevents leaking whether a
|
||||||
|
prompt_id exists to users who don't own it. Admins bypass this check.
|
||||||
|
"""
|
||||||
|
if user.get("admin"):
|
||||||
|
return
|
||||||
|
from generation_db import get_generation
|
||||||
|
gen = get_generation(prompt_id)
|
||||||
|
if gen is None or gen["user_label"] != user["sub"]:
|
||||||
|
raise HTTPException(404, "Not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def get_history(
|
||||||
|
user: dict = Depends(require_auth),
|
||||||
|
q: Optional[str] = Query(None, description="Keyword to search in overrides JSON"),
|
||||||
|
):
|
||||||
|
"""Return generation history. Admins see all; regular users see only their own.
|
||||||
|
Pass ?q=keyword to filter by prompt text or any override field."""
|
||||||
|
from generation_db import (
|
||||||
|
get_history as db_get_history,
|
||||||
|
get_history_for_user,
|
||||||
|
search_history,
|
||||||
|
search_history_for_user,
|
||||||
|
)
|
||||||
|
if q and q.strip():
|
||||||
|
if user.get("admin"):
|
||||||
|
return {"history": search_history(q.strip(), limit=50)}
|
||||||
|
return {"history": search_history_for_user(user["sub"], q.strip(), limit=50)}
|
||||||
|
if user.get("admin"):
|
||||||
|
return {"history": db_get_history(limit=50)}
|
||||||
|
return {"history": get_history_for_user(user["sub"], limit=50)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{prompt_id}/images")
|
||||||
|
async def get_history_images(prompt_id: str, user: dict = Depends(require_auth)):
|
||||||
|
"""
|
||||||
|
Fetch output files for a past generation.
|
||||||
|
|
||||||
|
Returns base64-encoded blobs from the local SQLite DB — works even after
|
||||||
|
``flush_pending`` has deleted the files from disk.
|
||||||
|
"""
|
||||||
|
_assert_owns(prompt_id, user)
|
||||||
|
from generation_db import get_files
|
||||||
|
files = get_files(prompt_id)
|
||||||
|
if not files:
|
||||||
|
raise HTTPException(404, f"No files found for prompt_id {prompt_id!r}")
|
||||||
|
return {
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
"images": [
|
||||||
|
{
|
||||||
|
"filename": f["filename"],
|
||||||
|
"data": base64.b64encode(f["data"]).decode() if not f["mime_type"].startswith("video/") else None,
|
||||||
|
"mime_type": f["mime_type"],
|
||||||
|
}
|
||||||
|
for f in files
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{prompt_id}/file/{filename}")
|
||||||
|
async def get_history_file(
|
||||||
|
prompt_id: str,
|
||||||
|
filename: str,
|
||||||
|
request: Request,
|
||||||
|
user: dict = Depends(require_auth),
|
||||||
|
):
|
||||||
|
"""Stream a single output file, with HTTP range request support for video seeking."""
|
||||||
|
_assert_owns(prompt_id, user)
|
||||||
|
from generation_db import get_files
|
||||||
|
files = get_files(prompt_id)
|
||||||
|
matched = next((f for f in files if f["filename"] == filename), None)
|
||||||
|
if matched is None:
|
||||||
|
raise HTTPException(404, f"File {filename!r} not found for prompt_id {prompt_id!r}")
|
||||||
|
|
||||||
|
data: bytes = matched["data"]
|
||||||
|
mime: str = matched["mime_type"]
|
||||||
|
total = len(data)
|
||||||
|
|
||||||
|
range_header = request.headers.get("range")
|
||||||
|
if range_header:
|
||||||
|
range_val = range_header.replace("bytes=", "")
|
||||||
|
start_str, _, end_str = range_val.partition("-")
|
||||||
|
start = int(start_str) if start_str else 0
|
||||||
|
end = int(end_str) if end_str else total - 1
|
||||||
|
end = min(end, total - 1)
|
||||||
|
chunk = data[start : end + 1]
|
||||||
|
return Response(
|
||||||
|
content=chunk,
|
||||||
|
status_code=206,
|
||||||
|
media_type=mime,
|
||||||
|
headers={
|
||||||
|
"Content-Range": f"bytes {start}-{end}/{total}",
|
||||||
|
"Accept-Ranges": "bytes",
|
||||||
|
"Content-Length": str(len(chunk)),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
content=data,
|
||||||
|
media_type=mime,
|
||||||
|
headers={
|
||||||
|
"Accept-Ranges": "bytes",
|
||||||
|
"Content-Length": str(total),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{prompt_id}/share")
|
||||||
|
async def create_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
|
||||||
|
"""Create a share token for a generation. Only the owner may share."""
|
||||||
|
# Use the same 404-for-everything helper to avoid leaking prompt_id existence
|
||||||
|
_assert_owns(prompt_id, user)
|
||||||
|
from generation_db import create_share as db_create_share
|
||||||
|
token = db_create_share(prompt_id, user["sub"])
|
||||||
|
return {"share_token": token}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{prompt_id}/share")
|
||||||
|
async def revoke_generation_share(prompt_id: str, user: dict = Depends(require_auth)):
|
||||||
|
"""Revoke a share token for a generation. Only the owner may revoke."""
|
||||||
|
from generation_db import revoke_share as db_revoke_share
|
||||||
|
deleted = db_revoke_share(prompt_id, user["sub"])
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(404, "No active share found for this generation")
|
||||||
|
return {"ok": True}
|
||||||
194
web/routers/inputs_router.py
Normal file
194
web/routers/inputs_router.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
"""GET/POST/DELETE /api/inputs; GET /api/inputs/{id}/image; POST /api/inputs/{id}/activate"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, UploadFile
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from web.auth import require_auth
|
||||||
|
from web.deps import get_config, get_user_registry
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_inputs(_: dict = Depends(require_auth)):
|
||||||
|
"""List all input images (Discord + web uploads)."""
|
||||||
|
from input_image_db import get_all_images
|
||||||
|
rows = get_all_images()
|
||||||
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
|
||||||
|
async def upload_input(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
slot_key: Optional[str] = Form(default=None),
|
||||||
|
user: dict = Depends(require_auth),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Upload an input image.
|
||||||
|
|
||||||
|
Stores image bytes directly in SQLite. If *slot_key* is provided the
|
||||||
|
image is immediately activated for that slot (writes to ComfyUI input
|
||||||
|
folder and updates the user's state override).
|
||||||
|
|
||||||
|
The physical slot file uses a namespaced key ``<user_label>_<slot_key>``
|
||||||
|
so concurrent users each get their own active image file.
|
||||||
|
"""
|
||||||
|
config = get_config()
|
||||||
|
if config is None:
|
||||||
|
raise HTTPException(503, "Config not available")
|
||||||
|
|
||||||
|
data = await file.read()
|
||||||
|
filename = file.filename or "upload.png"
|
||||||
|
|
||||||
|
from input_image_db import upsert_image, activate_image_for_slot
|
||||||
|
row_id = upsert_image(
|
||||||
|
original_message_id=0, # sentinel for web uploads
|
||||||
|
bot_reply_id=0,
|
||||||
|
channel_id=0,
|
||||||
|
filename=filename,
|
||||||
|
image_data=data,
|
||||||
|
)
|
||||||
|
|
||||||
|
activated_filename: str | None = None
|
||||||
|
if slot_key:
|
||||||
|
user_label: str = user["sub"]
|
||||||
|
namespaced_key = f"{user_label}_{slot_key}"
|
||||||
|
activated_filename = activate_image_for_slot(
|
||||||
|
row_id, namespaced_key, config.comfy_input_path
|
||||||
|
)
|
||||||
|
registry = get_user_registry()
|
||||||
|
if registry:
|
||||||
|
registry.get_state_manager(user_label).set_override(slot_key, activated_filename)
|
||||||
|
else:
|
||||||
|
from web.deps import get_comfy
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy:
|
||||||
|
comfy.state_manager.set_override(slot_key, activated_filename)
|
||||||
|
|
||||||
|
return {"id": row_id, "filename": filename, "slot_key": slot_key, "activated_filename": activated_filename}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{row_id}/activate")
|
||||||
|
async def activate_input(
|
||||||
|
row_id: int,
|
||||||
|
slot_key: str = Body(default="input_image", embed=True),
|
||||||
|
user: dict = Depends(require_auth),
|
||||||
|
):
|
||||||
|
"""Write the stored image to the ComfyUI input folder and set the user's slot override."""
|
||||||
|
config = get_config()
|
||||||
|
if config is None:
|
||||||
|
raise HTTPException(503, "Config not available")
|
||||||
|
|
||||||
|
from input_image_db import get_image, activate_image_for_slot
|
||||||
|
row = get_image(row_id)
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(404, "Image not found")
|
||||||
|
|
||||||
|
user_label: str = user["sub"]
|
||||||
|
namespaced_key = f"{user_label}_{slot_key}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
filename = activate_image_for_slot(row_id, namespaced_key, config.comfy_input_path)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(409, str(exc))
|
||||||
|
|
||||||
|
registry = get_user_registry()
|
||||||
|
if registry:
|
||||||
|
registry.get_state_manager(user_label).set_override(slot_key, filename)
|
||||||
|
else:
|
||||||
|
from web.deps import get_comfy
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy is None:
|
||||||
|
raise HTTPException(503, "State manager not available")
|
||||||
|
comfy.state_manager.set_override(slot_key, filename)
|
||||||
|
|
||||||
|
return {"ok": True, "slot_key": slot_key, "filename": filename}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{row_id}")
|
||||||
|
async def delete_input(row_id: int, _: dict = Depends(require_auth)):
|
||||||
|
"""Delete an input image record (and its active slot file if applicable)."""
|
||||||
|
from input_image_db import get_image, delete_image
|
||||||
|
row = get_image(row_id)
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(404, "Image not found")
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
delete_image(row_id, comfy_input_path=config.comfy_input_path if config else None)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{row_id}/image")
|
||||||
|
async def get_input_image(row_id: int, _: dict = Depends(require_auth)):
|
||||||
|
"""Serve the raw image bytes stored in the database for a given input image row."""
|
||||||
|
from input_image_db import get_image, get_image_data
|
||||||
|
row = get_image(row_id)
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(404, "Image not found")
|
||||||
|
|
||||||
|
data = get_image_data(row_id)
|
||||||
|
if data is None:
|
||||||
|
raise HTTPException(404, "Image data not available — re-upload to backfill")
|
||||||
|
|
||||||
|
mime, _ = mimetypes.guess_type(row["filename"])
|
||||||
|
return Response(content=data, media_type=mime or "application/octet-stream")
|
||||||
|
|
||||||
|
|
||||||
|
def _pil_resize_response(data: bytes, filename: str, max_size: int, quality: int) -> Response:
|
||||||
|
"""Resize image bytes with Pillow and return a JPEG Response. Raises on failure."""
|
||||||
|
import io
|
||||||
|
from PIL import Image as _PIL
|
||||||
|
img = _PIL.open(io.BytesIO(data))
|
||||||
|
img.thumbnail((max_size, max_size), _PIL.LANCZOS)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
img.convert("RGB").save(buf, "JPEG", quality=quality, optimize=True)
|
||||||
|
return Response(
|
||||||
|
content=buf.getvalue(),
|
||||||
|
media_type="image/jpeg",
|
||||||
|
headers={"Cache-Control": "public, max-age=86400"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{row_id}/thumb")
|
||||||
|
async def get_input_thumb(row_id: int, _: dict = Depends(require_auth)):
|
||||||
|
"""Serve a small compressed thumbnail (max 200 px, JPEG 65 %) for fast previews."""
|
||||||
|
from input_image_db import get_image, get_image_data
|
||||||
|
row = get_image(row_id)
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(404, "Image not found")
|
||||||
|
|
||||||
|
data = get_image_data(row_id)
|
||||||
|
if data is None:
|
||||||
|
raise HTTPException(404, "Image data not available — re-upload to backfill")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return _pil_resize_response(data, row["filename"], max_size=200, quality=65)
|
||||||
|
except Exception:
|
||||||
|
mime, _ = mimetypes.guess_type(row["filename"])
|
||||||
|
return Response(content=data, media_type=mime or "application/octet-stream")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{row_id}/mid")
|
||||||
|
async def get_input_mid(row_id: int, _: dict = Depends(require_auth)):
|
||||||
|
"""Serve a medium compressed image (max 800 px, JPEG 80 %) for progressive loading."""
|
||||||
|
from input_image_db import get_image, get_image_data
|
||||||
|
row = get_image(row_id)
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(404, "Image not found")
|
||||||
|
|
||||||
|
data = get_image_data(row_id)
|
||||||
|
if data is None:
|
||||||
|
raise HTTPException(404, "Image data not available — re-upload to backfill")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return _pil_resize_response(data, row["filename"], max_size=800, quality=80)
|
||||||
|
except Exception:
|
||||||
|
mime, _ = mimetypes.guess_type(row["filename"])
|
||||||
|
return Response(content=data, media_type=mime or "application/octet-stream")
|
||||||
153
web/routers/presets_router.py
Normal file
153
web/routers/presets_router.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""CRUD for workflow presets via /api/presets"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from web.auth import require_auth
|
||||||
|
from web.deps import get_comfy, get_user_registry
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class SavePresetRequest(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SaveFromHistoryRequest(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pm():
|
||||||
|
from web.deps import get_bot
|
||||||
|
bot = get_bot()
|
||||||
|
pm = getattr(bot, "preset_manager", None) if bot else None
|
||||||
|
if pm is None:
|
||||||
|
from preset_manager import PresetManager
|
||||||
|
pm = PresetManager()
|
||||||
|
return pm
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_presets(_: dict = Depends(require_auth)):
|
||||||
|
pm = _get_pm()
|
||||||
|
return {"presets": pm.list_preset_details()}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
|
||||||
|
async def save_preset(body: SavePresetRequest, user: dict = Depends(require_auth)):
|
||||||
|
"""Capture the user's overrides + workflow template as a named preset."""
|
||||||
|
user_label: str = user["sub"]
|
||||||
|
registry = get_user_registry()
|
||||||
|
pm = _get_pm()
|
||||||
|
|
||||||
|
if registry:
|
||||||
|
workflow_template = registry.get_workflow_template(user_label)
|
||||||
|
overrides = registry.get_state_manager(user_label).get_overrides()
|
||||||
|
else:
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy is None:
|
||||||
|
raise HTTPException(503, "ComfyUI not available")
|
||||||
|
workflow_template = comfy.get_workflow_template()
|
||||||
|
overrides = comfy.state_manager.get_overrides()
|
||||||
|
|
||||||
|
try:
|
||||||
|
pm.save(body.name, workflow_template, overrides, owner=user_label, description=body.description)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(400, str(exc))
|
||||||
|
return {"ok": True, "name": body.name}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{name}")
|
||||||
|
async def get_preset(name: str, _: dict = Depends(require_auth)):
|
||||||
|
pm = _get_pm()
|
||||||
|
data = pm.load(name)
|
||||||
|
if data is None:
|
||||||
|
raise HTTPException(404, "Preset not found")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{name}/load")
|
||||||
|
async def load_preset(name: str, user: dict = Depends(require_auth)):
|
||||||
|
"""Restore overrides (and optionally workflow template) from a preset into the user's state."""
|
||||||
|
pm = _get_pm()
|
||||||
|
data = pm.load(name)
|
||||||
|
if data is None:
|
||||||
|
raise HTTPException(404, "Preset not found")
|
||||||
|
|
||||||
|
user_label: str = user["sub"]
|
||||||
|
registry = get_user_registry()
|
||||||
|
|
||||||
|
if registry:
|
||||||
|
wf = data.get("workflow")
|
||||||
|
if wf:
|
||||||
|
registry.set_workflow(user_label, wf, name)
|
||||||
|
else:
|
||||||
|
# No workflow in preset — just clear overrides and restore state
|
||||||
|
registry.get_state_manager(user_label).clear_overrides()
|
||||||
|
state = data.get("state", {})
|
||||||
|
sm = registry.get_state_manager(user_label)
|
||||||
|
for k, v in state.items():
|
||||||
|
if v is not None:
|
||||||
|
sm.set_override(k, v)
|
||||||
|
else:
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy is None:
|
||||||
|
raise HTTPException(503, "ComfyUI not available")
|
||||||
|
comfy.state_manager.clear_overrides()
|
||||||
|
state = data.get("state", {})
|
||||||
|
for k, v in state.items():
|
||||||
|
if v is not None:
|
||||||
|
comfy.state_manager.set_override(k, v)
|
||||||
|
wf = data.get("workflow")
|
||||||
|
if wf:
|
||||||
|
comfy.workflow_manager.set_workflow_template(wf)
|
||||||
|
|
||||||
|
return {"ok": True, "name": name, "overrides_restored": list(data.get("state", {}).keys())}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{name}")
|
||||||
|
async def delete_preset(name: str, user: dict = Depends(require_auth)):
|
||||||
|
pm = _get_pm()
|
||||||
|
data = pm.load(name)
|
||||||
|
if data is None:
|
||||||
|
raise HTTPException(404, "Preset not found")
|
||||||
|
|
||||||
|
user_label: str = user["sub"]
|
||||||
|
is_admin = user.get("admin") is True
|
||||||
|
owner = data.get("owner")
|
||||||
|
if owner is not None and owner != user_label and not is_admin:
|
||||||
|
raise HTTPException(403, "You do not have permission to delete this preset")
|
||||||
|
|
||||||
|
pm.delete(name)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/from-history/{prompt_id}")
|
||||||
|
async def save_preset_from_history(
|
||||||
|
prompt_id: str,
|
||||||
|
body: SaveFromHistoryRequest,
|
||||||
|
user: dict = Depends(require_auth),
|
||||||
|
):
|
||||||
|
"""Create a preset from a past generation's overrides."""
|
||||||
|
from generation_db import get_generation_full
|
||||||
|
|
||||||
|
gen = get_generation_full(prompt_id)
|
||||||
|
if gen is None:
|
||||||
|
raise HTTPException(404, "Generation not found")
|
||||||
|
|
||||||
|
user_label: str = user["sub"]
|
||||||
|
is_admin = user.get("admin") is True
|
||||||
|
if not is_admin and gen.get("user_label") != user_label:
|
||||||
|
raise HTTPException(404, "Generation not found")
|
||||||
|
|
||||||
|
pm = _get_pm()
|
||||||
|
try:
|
||||||
|
pm.save(body.name, None, gen["overrides"], owner=user_label, description=body.description)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(400, str(exc))
|
||||||
|
return {"ok": True, "name": body.name}
|
||||||
90
web/routers/server_router.py
Normal file
90
web/routers/server_router.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""GET/POST /api/server/{action}; GET /api/logs/tail"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
|
||||||
|
from web.auth import require_auth
|
||||||
|
from web.deps import get_config, get_comfy
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/server/status")
|
||||||
|
async def server_status(_: dict = Depends(require_auth)):
|
||||||
|
"""Return NSSM service state and HTTP health."""
|
||||||
|
config = get_config()
|
||||||
|
if config is None:
|
||||||
|
raise HTTPException(503, "Config not available")
|
||||||
|
from commands.server import get_service_state
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def _false():
|
||||||
|
return False
|
||||||
|
|
||||||
|
comfy = get_comfy()
|
||||||
|
service_state, http_ok = await asyncio.gather(
|
||||||
|
get_service_state(config.comfy_service_name),
|
||||||
|
comfy.check_connection() if comfy else _false(),
|
||||||
|
)
|
||||||
|
return {"service_state": service_state, "http_reachable": http_ok}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/server/{action}")
|
||||||
|
async def server_action(action: str, _: dict = Depends(require_auth)):
|
||||||
|
"""Control the ComfyUI service: start | stop | restart | install | uninstall"""
|
||||||
|
config = get_config()
|
||||||
|
if config is None:
|
||||||
|
raise HTTPException(503, "Config not available")
|
||||||
|
|
||||||
|
valid_actions = {"start", "stop", "restart", "install", "uninstall"}
|
||||||
|
if action not in valid_actions:
|
||||||
|
raise HTTPException(400, f"Invalid action '{action}'")
|
||||||
|
|
||||||
|
from commands.server import _nssm, _install_service
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
try:
|
||||||
|
if action == "install":
|
||||||
|
ok, msg = await _install_service(config)
|
||||||
|
if not ok:
|
||||||
|
raise HTTPException(500, msg)
|
||||||
|
elif action == "uninstall":
|
||||||
|
await _nssm("stop", config.comfy_service_name)
|
||||||
|
await _nssm("remove", config.comfy_service_name, "confirm")
|
||||||
|
elif action == "start":
|
||||||
|
await _nssm("start", config.comfy_service_name)
|
||||||
|
elif action == "stop":
|
||||||
|
await _nssm("stop", config.comfy_service_name)
|
||||||
|
elif action == "restart":
|
||||||
|
await _nssm("restart", config.comfy_service_name)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(500, str(exc))
|
||||||
|
|
||||||
|
return {"ok": True, "action": action}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/logs/tail")
|
||||||
|
async def tail_logs(lines: int = 100, _: dict = Depends(require_auth)):
|
||||||
|
"""Tail the ComfyUI log file."""
|
||||||
|
config = get_config()
|
||||||
|
if config is None or not config.comfy_log_dir:
|
||||||
|
raise HTTPException(503, "Log directory not configured")
|
||||||
|
|
||||||
|
log_dir = Path(config.comfy_log_dir)
|
||||||
|
log_file = log_dir / "comfyui.log"
|
||||||
|
if not log_file.exists():
|
||||||
|
return {"lines": []}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(log_file, "r", encoding="utf-8", errors="replace") as f:
|
||||||
|
all_lines = f.readlines()
|
||||||
|
tail = all_lines[-min(lines, len(all_lines)):]
|
||||||
|
return {"lines": [ln.rstrip("\n") for ln in tail]}
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(500, str(exc))
|
||||||
85
web/routers/share_router.py
Normal file
85
web/routers/share_router.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""GET /api/share/{token}; GET /api/share/{token}/file/{filename}"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from web.auth import require_auth
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{token}")
|
||||||
|
async def get_share(token: str, _: dict = Depends(require_auth)):
|
||||||
|
"""Fetch share metadata and images. Any authenticated user may view a valid share link."""
|
||||||
|
from generation_db import get_share_by_token, get_files
|
||||||
|
gen = get_share_by_token(token)
|
||||||
|
if gen is None:
|
||||||
|
raise HTTPException(404, "Share not found or revoked")
|
||||||
|
files = get_files(gen["prompt_id"])
|
||||||
|
return {
|
||||||
|
"prompt_id": gen["prompt_id"],
|
||||||
|
"created_at": gen["created_at"],
|
||||||
|
"overrides": gen["overrides"],
|
||||||
|
"seed": gen["seed"],
|
||||||
|
"images": [
|
||||||
|
{
|
||||||
|
"filename": f["filename"],
|
||||||
|
"data": base64.b64encode(f["data"]).decode() if not f["mime_type"].startswith("video/") else None,
|
||||||
|
"mime_type": f["mime_type"],
|
||||||
|
}
|
||||||
|
for f in files
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{token}/file/{filename}")
|
||||||
|
async def get_share_file(
|
||||||
|
token: str,
|
||||||
|
filename: str,
|
||||||
|
request: Request,
|
||||||
|
_: dict = Depends(require_auth),
|
||||||
|
):
|
||||||
|
"""Stream a single output file via share token, with HTTP range support for video seeking."""
|
||||||
|
from generation_db import get_share_by_token, get_files
|
||||||
|
gen = get_share_by_token(token)
|
||||||
|
if gen is None:
|
||||||
|
raise HTTPException(404, "Share not found or revoked")
|
||||||
|
files = get_files(gen["prompt_id"])
|
||||||
|
matched = next((f for f in files if f["filename"] == filename), None)
|
||||||
|
if matched is None:
|
||||||
|
raise HTTPException(404, f"File {filename!r} not found")
|
||||||
|
|
||||||
|
data: bytes = matched["data"]
|
||||||
|
mime: str = matched["mime_type"]
|
||||||
|
total = len(data)
|
||||||
|
|
||||||
|
range_header = request.headers.get("range")
|
||||||
|
if range_header:
|
||||||
|
range_val = range_header.replace("bytes=", "")
|
||||||
|
start_str, _, end_str = range_val.partition("-")
|
||||||
|
start = int(start_str) if start_str else 0
|
||||||
|
end = int(end_str) if end_str else total - 1
|
||||||
|
end = min(end, total - 1)
|
||||||
|
chunk = data[start : end + 1]
|
||||||
|
return Response(
|
||||||
|
content=chunk,
|
||||||
|
status_code=206,
|
||||||
|
media_type=mime,
|
||||||
|
headers={
|
||||||
|
"Content-Range": f"bytes {start}-{end}/{total}",
|
||||||
|
"Accept-Ranges": "bytes",
|
||||||
|
"Content-Length": str(len(chunk)),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
content=data,
|
||||||
|
media_type=mime,
|
||||||
|
headers={
|
||||||
|
"Accept-Ranges": "bytes",
|
||||||
|
"Content-Length": str(total),
|
||||||
|
},
|
||||||
|
)
|
||||||
53
web/routers/state_router.py
Normal file
53
web/routers/state_router.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""GET/PUT /api/state; DELETE /api/state/{key}"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
|
||||||
|
from web.auth import require_auth
|
||||||
|
from web.deps import get_config, get_user_registry
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_sm(user: dict):
|
||||||
|
"""Return the per-user WorkflowStateManager, raising 503 if unavailable."""
|
||||||
|
registry = get_user_registry()
|
||||||
|
if registry is None:
|
||||||
|
raise HTTPException(503, "State manager not available")
|
||||||
|
return registry.get_state_manager(user["sub"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/state")
|
||||||
|
async def get_state(user: dict = Depends(require_auth)):
|
||||||
|
"""Return all current overrides for the authenticated user."""
|
||||||
|
sm = _get_user_sm(user)
|
||||||
|
return sm.get_overrides()
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/state")
|
||||||
|
async def put_state(body: Dict[str, Any], user: dict = Depends(require_auth)):
|
||||||
|
"""Merge override values. Pass ``null`` as a value to delete a key."""
|
||||||
|
sm = _get_user_sm(user)
|
||||||
|
for key, value in body.items():
|
||||||
|
if value is None:
|
||||||
|
sm.delete_override(key)
|
||||||
|
else:
|
||||||
|
sm.set_override(key, value)
|
||||||
|
return sm.get_overrides()
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/state/{key}")
|
||||||
|
async def delete_state_key(key: str, user: dict = Depends(require_auth)):
|
||||||
|
"""Remove a single override key, and clean up any associated slot file."""
|
||||||
|
sm = _get_user_sm(user)
|
||||||
|
sm.delete_override(key)
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
if config:
|
||||||
|
from input_image_db import deactivate_image_slot
|
||||||
|
user_label: str = user["sub"]
|
||||||
|
deactivate_image_slot(f"{user_label}_{key}", config.comfy_input_path)
|
||||||
|
|
||||||
|
return {"ok": True, "key": key}
|
||||||
74
web/routers/status_router.py
Normal file
74
web/routers/status_router.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""GET /api/status — polling fallback for clients that can't use WebSocket"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from web.auth import require_auth
|
||||||
|
from web.deps import get_bot, get_comfy, get_config
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status")
|
||||||
|
async def get_status(_: dict = Depends(require_auth)):
|
||||||
|
"""Return a full status snapshot."""
|
||||||
|
bot = get_bot()
|
||||||
|
comfy = get_comfy()
|
||||||
|
config = get_config()
|
||||||
|
|
||||||
|
snap: dict = {}
|
||||||
|
|
||||||
|
if bot is not None:
|
||||||
|
import datetime as _dt
|
||||||
|
lat = bot.latency
|
||||||
|
lat_ms = round(lat * 1000) if (lat is not None and lat != float("inf")) else 0
|
||||||
|
start = getattr(bot, "start_time", None)
|
||||||
|
uptime = ""
|
||||||
|
if start:
|
||||||
|
delta = _dt.datetime.now(_dt.timezone.utc) - start
|
||||||
|
total = int(delta.total_seconds())
|
||||||
|
h, rem = divmod(total, 3600)
|
||||||
|
m, s = divmod(rem, 60)
|
||||||
|
uptime = f"{h}h {m}m {s}s" if h else (f"{m}m {s}s" if m else f"{s}s")
|
||||||
|
snap["bot"] = {"latency_ms": lat_ms, "uptime": uptime}
|
||||||
|
|
||||||
|
if comfy is not None:
|
||||||
|
q_task = asyncio.create_task(comfy.get_comfy_queue())
|
||||||
|
conn_task = asyncio.create_task(comfy.check_connection())
|
||||||
|
q, reachable = await asyncio.gather(q_task, conn_task)
|
||||||
|
|
||||||
|
pending = len(q.get("queue_pending", [])) if q else 0
|
||||||
|
running = len(q.get("queue_running", [])) if q else 0
|
||||||
|
wm = getattr(comfy, "workflow_manager", None)
|
||||||
|
wf_loaded = wm is not None and wm.get_workflow_template() is not None
|
||||||
|
|
||||||
|
snap["comfy"] = {
|
||||||
|
"server": comfy.server_address,
|
||||||
|
"reachable": reachable,
|
||||||
|
"queue_pending": pending,
|
||||||
|
"queue_running": running,
|
||||||
|
"workflow_loaded": wf_loaded,
|
||||||
|
"last_seed": comfy.last_seed,
|
||||||
|
"total_generated": comfy.total_generated,
|
||||||
|
}
|
||||||
|
snap["overrides"] = comfy.state_manager.get_overrides()
|
||||||
|
|
||||||
|
if config is not None:
|
||||||
|
from commands.server import get_service_state
|
||||||
|
service_state = await get_service_state(config.comfy_service_name)
|
||||||
|
snap["service"] = {"state": service_state}
|
||||||
|
|
||||||
|
try:
|
||||||
|
from media_uploader import get_stats as us_fn, is_running as ur_fn
|
||||||
|
us = us_fn()
|
||||||
|
snap["upload"] = {
|
||||||
|
"configured": bool(config.media_upload_user),
|
||||||
|
"running": ur_fn(),
|
||||||
|
"total_ok": us.total_ok,
|
||||||
|
"total_fail": us.total_fail,
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return snap
|
||||||
175
web/routers/workflow_router.py
Normal file
175
web/routers/workflow_router.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""
|
||||||
|
GET /api/workflow — current workflow info
|
||||||
|
GET /api/workflow/inputs — dynamic NodeInput list
|
||||||
|
GET /api/workflow/files — list files in workflows/
|
||||||
|
POST /api/workflow/upload — upload a workflow JSON
|
||||||
|
POST /api/workflow/load — load a workflow from workflows/
|
||||||
|
GET /api/workflow/models?type=checkpoints|loras — available models (60s TTL cache)
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||||
|
|
||||||
|
from web.auth import require_auth
|
||||||
|
from web.deps import get_comfy, get_config, get_inspector, get_user_registry
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||||
|
_WORKFLOWS_DIR = _PROJECT_ROOT / "workflows"
|
||||||
|
|
||||||
|
# Simple in-memory TTL cache for models
|
||||||
|
_models_cache: dict = {}
|
||||||
|
_MODELS_TTL = 60.0
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def get_workflow(user: dict = Depends(require_auth)):
|
||||||
|
"""Return basic info about the currently loaded workflow."""
|
||||||
|
user_label: str = user["sub"]
|
||||||
|
registry = get_user_registry()
|
||||||
|
|
||||||
|
if registry:
|
||||||
|
template = registry.get_workflow_template(user_label)
|
||||||
|
last_wf = registry.get_state_manager(user_label).get_last_workflow_file()
|
||||||
|
else:
|
||||||
|
# Fallback to global state when registry is unavailable
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy is None:
|
||||||
|
raise HTTPException(503, "Workflow manager not available")
|
||||||
|
template = comfy.workflow_manager.get_workflow_template()
|
||||||
|
last_wf = comfy.state_manager.get_last_workflow_file()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"loaded": template is not None,
|
||||||
|
"node_count": len(template) if template else 0,
|
||||||
|
"last_workflow_file": last_wf,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/inputs")
|
||||||
|
async def get_workflow_inputs(user: dict = Depends(require_auth)):
|
||||||
|
"""Return dynamic NodeInput list (common + advanced) for the current workflow."""
|
||||||
|
user_label: str = user["sub"]
|
||||||
|
inspector = get_inspector()
|
||||||
|
if inspector is None:
|
||||||
|
raise HTTPException(503, "Workflow components not available")
|
||||||
|
|
||||||
|
registry = get_user_registry()
|
||||||
|
if registry:
|
||||||
|
template = registry.get_workflow_template(user_label)
|
||||||
|
overrides = registry.get_state_manager(user_label).get_overrides()
|
||||||
|
else:
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy is None:
|
||||||
|
raise HTTPException(503, "Workflow components not available")
|
||||||
|
template = comfy.workflow_manager.get_workflow_template()
|
||||||
|
overrides = comfy.state_manager.get_overrides()
|
||||||
|
|
||||||
|
if template is None:
|
||||||
|
return {"common": [], "advanced": []}
|
||||||
|
|
||||||
|
inputs = inspector.inspect(template)
|
||||||
|
result = []
|
||||||
|
for ni in inputs:
|
||||||
|
val = overrides.get(ni.key, ni.current_value)
|
||||||
|
result.append({
|
||||||
|
"key": ni.key,
|
||||||
|
"label": ni.label,
|
||||||
|
"input_type": ni.input_type,
|
||||||
|
"current_value": val,
|
||||||
|
"node_class": ni.node_class,
|
||||||
|
"node_title": ni.node_title,
|
||||||
|
"is_common": ni.is_common,
|
||||||
|
})
|
||||||
|
common = [r for r in result if r["is_common"]]
|
||||||
|
advanced = [r for r in result if not r["is_common"]]
|
||||||
|
return {"common": common, "advanced": advanced}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files")
|
||||||
|
async def list_workflow_files(_: dict = Depends(require_auth)):
|
||||||
|
"""List .json files in the workflows/ folder."""
|
||||||
|
_WORKFLOWS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
files = sorted(p.name for p in _WORKFLOWS_DIR.glob("*.json"))
|
||||||
|
return {"files": files}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/upload")
|
||||||
|
async def upload_workflow(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
_: dict = Depends(require_auth),
|
||||||
|
):
|
||||||
|
"""Upload a workflow JSON to the workflows/ folder."""
|
||||||
|
_WORKFLOWS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
filename = file.filename or "workflow.json"
|
||||||
|
if not filename.endswith(".json"):
|
||||||
|
filename += ".json"
|
||||||
|
data = await file.read()
|
||||||
|
try:
|
||||||
|
json.loads(data) # validate JSON
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise HTTPException(400, f"Invalid JSON: {exc}")
|
||||||
|
dest = _WORKFLOWS_DIR / filename
|
||||||
|
dest.write_bytes(data)
|
||||||
|
return {"ok": True, "filename": filename}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/load")
|
||||||
|
async def load_workflow(filename: str = Form(...), user: dict = Depends(require_auth)):
|
||||||
|
"""Load a workflow from the workflows/ folder into the user's isolated state."""
|
||||||
|
wf_path = _WORKFLOWS_DIR / filename
|
||||||
|
if not wf_path.exists():
|
||||||
|
raise HTTPException(404, f"Workflow file '{filename}' not found")
|
||||||
|
try:
|
||||||
|
with open(wf_path, "r", encoding="utf-8") as f:
|
||||||
|
workflow = json.load(f)
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(500, str(exc))
|
||||||
|
|
||||||
|
registry = get_user_registry()
|
||||||
|
if registry:
|
||||||
|
registry.set_workflow(user["sub"], workflow, filename)
|
||||||
|
else:
|
||||||
|
# Fallback: update global state when registry unavailable
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy is None:
|
||||||
|
raise HTTPException(503, "ComfyUI not available")
|
||||||
|
comfy.workflow_manager.set_workflow_template(workflow)
|
||||||
|
comfy.state_manager.clear_overrides()
|
||||||
|
comfy.state_manager.set_last_workflow_file(filename)
|
||||||
|
|
||||||
|
inspector = get_inspector()
|
||||||
|
node_count = len(workflow)
|
||||||
|
inputs_count = len(inspector.inspect(workflow)) if inspector else 0
|
||||||
|
return {
|
||||||
|
"ok": True,
|
||||||
|
"filename": filename,
|
||||||
|
"node_count": node_count,
|
||||||
|
"inputs_count": inputs_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/models")
|
||||||
|
async def get_models(type: str = "checkpoints", _: dict = Depends(require_auth)):
|
||||||
|
"""Return available model names from ComfyUI (60s TTL cache)."""
|
||||||
|
global _models_cache
|
||||||
|
now = time.time()
|
||||||
|
cache_key = type
|
||||||
|
cached = _models_cache.get(cache_key)
|
||||||
|
if cached and (now - cached["ts"]) < _MODELS_TTL:
|
||||||
|
return {"type": type, "models": cached["models"]}
|
||||||
|
|
||||||
|
comfy = get_comfy()
|
||||||
|
if comfy is None:
|
||||||
|
raise HTTPException(503, "ComfyUI not available")
|
||||||
|
models = await comfy.get_models(type)
|
||||||
|
_models_cache[cache_key] = {"models": models, "ts": now}
|
||||||
|
return {"type": type, "models": models}
|
||||||
61
web/routers/ws_router.py
Normal file
61
web/routers/ws_router.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""WebSocket /ws?token=<jwt> — real-time event stream"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
|
|
||||||
|
from web.auth import verify_ws_token
|
||||||
|
from web.ws_bus import get_bus
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/ws")
|
||||||
|
async def websocket_endpoint(websocket: WebSocket, token: str = ""):
|
||||||
|
"""
|
||||||
|
Authenticate via JWT query param or ttb_session cookie, then stream events from WSBus.
|
||||||
|
|
||||||
|
Events common to all users: status_snapshot, queue_update, node_executing, server_state
|
||||||
|
Events private to submitter: generation_complete, generation_error
|
||||||
|
"""
|
||||||
|
payload = verify_ws_token(token)
|
||||||
|
if payload is None:
|
||||||
|
# Fallback: browsers send cookies automatically with WebSocket connections
|
||||||
|
cookie_token = websocket.cookies.get("ttb_session", "")
|
||||||
|
payload = verify_ws_token(cookie_token)
|
||||||
|
if payload is None:
|
||||||
|
await websocket.close(code=4001, reason="Unauthorized")
|
||||||
|
return
|
||||||
|
|
||||||
|
user_label: str = payload.get("sub", "anonymous")
|
||||||
|
bus = get_bus()
|
||||||
|
queue = bus.subscribe(user_label)
|
||||||
|
await websocket.accept()
|
||||||
|
logger.info("WS connected: user=%s", user_label)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# Wait for an event from the bus
|
||||||
|
try:
|
||||||
|
frame = await asyncio.wait_for(queue.get(), timeout=30.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Send a keepalive ping
|
||||||
|
try:
|
||||||
|
await websocket.send_text('{"type":"ping"}')
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
await websocket.send_text(frame)
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
bus.unsubscribe(user_label, queue)
|
||||||
|
logger.info("WS disconnected: user=%s", user_label)
|
||||||
122
web/ws_bus.py
Normal file
122
web/ws_bus.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""
|
||||||
|
web/ws_bus.py
|
||||||
|
=============
|
||||||
|
|
||||||
|
In-process WebSocket event bus.
|
||||||
|
|
||||||
|
All connected web clients share a single WSBus instance. Events are
|
||||||
|
delivered per-user (private results) or to all users (shared status).
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
bus = WSBus()
|
||||||
|
|
||||||
|
# Subscribe (returns a queue; caller reads from it)
|
||||||
|
q = bus.subscribe("alice")
|
||||||
|
|
||||||
|
# Broadcast to all
|
||||||
|
await bus.broadcast("status_snapshot", {...})
|
||||||
|
|
||||||
|
# Broadcast to one user (all their open tabs)
|
||||||
|
await bus.broadcast_to_user("alice", "generation_complete", {...})
|
||||||
|
|
||||||
|
# Unsubscribe when WS disconnects
|
||||||
|
bus.unsubscribe("alice", q)
|
||||||
|
|
||||||
|
Event frame format sent on wire:
|
||||||
|
{"type": "event_name", "data": {...}, "ts": 1234567890.123}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, Set
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WSBus:
|
||||||
|
"""
|
||||||
|
Per-user broadcast bus backed by asyncio queues.
|
||||||
|
|
||||||
|
Thread-safe as long as all callers run in the same event loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# user_label → set of asyncio.Queue
|
||||||
|
self._clients: Dict[str, Set[asyncio.Queue]] = {}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Subscription lifecycle
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def subscribe(self, user_label: str) -> asyncio.Queue:
|
||||||
|
"""Register a new client connection. Returns the queue to read from."""
|
||||||
|
q: asyncio.Queue = asyncio.Queue(maxsize=256)
|
||||||
|
self._clients.setdefault(user_label, set()).add(q)
|
||||||
|
logger.debug("WSBus: %s subscribed (%d queues)", user_label,
|
||||||
|
len(self._clients[user_label]))
|
||||||
|
return q
|
||||||
|
|
||||||
|
def unsubscribe(self, user_label: str, queue: asyncio.Queue) -> None:
|
||||||
|
"""Remove a client connection."""
|
||||||
|
queues = self._clients.get(user_label, set())
|
||||||
|
queues.discard(queue)
|
||||||
|
if not queues:
|
||||||
|
self._clients.pop(user_label, None)
|
||||||
|
logger.debug("WSBus: %s unsubscribed", user_label)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connected_users(self) -> list[str]:
|
||||||
|
"""List of user labels with at least one active connection."""
|
||||||
|
return list(self._clients.keys())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_connections(self) -> int:
|
||||||
|
return sum(len(qs) for qs in self._clients.values())
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Broadcasting
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _frame(self, event_type: str, data: Any) -> str:
|
||||||
|
return json.dumps({"type": event_type, "data": data, "ts": time.time()})
|
||||||
|
|
||||||
|
async def broadcast(self, event_type: str, data: Any) -> None:
|
||||||
|
"""Send an event to ALL connected clients."""
|
||||||
|
frame = self._frame(event_type, data)
|
||||||
|
for queues in list(self._clients.values()):
|
||||||
|
for q in list(queues):
|
||||||
|
try:
|
||||||
|
q.put_nowait(frame)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
logger.warning("WSBus: queue full, dropping %s event", event_type)
|
||||||
|
|
||||||
|
async def broadcast_to_user(
|
||||||
|
self, user_label: str, event_type: str, data: Any
|
||||||
|
) -> None:
|
||||||
|
"""Send an event to all connections belonging to *user_label*."""
|
||||||
|
queues = self._clients.get(user_label, set())
|
||||||
|
if not queues:
|
||||||
|
logger.debug("WSBus: no clients for user '%s', dropping %s", user_label, event_type)
|
||||||
|
return
|
||||||
|
frame = self._frame(event_type, data)
|
||||||
|
for q in list(queues):
|
||||||
|
try:
|
||||||
|
q.put_nowait(frame)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
logger.warning("WSBus: queue full for %s, dropping %s", user_label, event_type)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton (set by web/app.py)
|
||||||
|
_bus: WSBus | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_bus() -> WSBus:
|
||||||
|
global _bus
|
||||||
|
if _bus is None:
|
||||||
|
_bus = WSBus()
|
||||||
|
return _bus
|
||||||
397
workflow_inspector.py
Normal file
397
workflow_inspector.py
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
"""
|
||||||
|
workflow_inspector.py
|
||||||
|
=====================
|
||||||
|
|
||||||
|
Dynamic workflow node inspection and injection for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
Replaces the hardcoded node-finding methods in WorkflowManager with a
|
||||||
|
general-purpose inspector that works with any ComfyUI workflow. The
|
||||||
|
inspector discovers injectable inputs at load time by walking the workflow
|
||||||
|
JSON, classifying each scalar input by class_type + input_name, and
|
||||||
|
assigning stable human-readable keys (e.g. ``"prompt"``, ``"seed"``,
|
||||||
|
``"input_image"``) that can be stored in WorkflowStateManager and used by
|
||||||
|
both Discord commands and the web UI.
|
||||||
|
|
||||||
|
Key-assignment rules
|
||||||
|
--------------------
|
||||||
|
- CLIPTextEncode / text → ``"prompt"`` / ``"negative_prompt"`` / ``"text_{node_id}"``
|
||||||
|
- LoadImage / image → ``"input_image"`` (first), ``{title_slug}`` (subsequent)
|
||||||
|
- KSampler* / seed → ``"seed"``
|
||||||
|
- KSampler / steps|cfg|sampler_name|scheduler|denoise → same name
|
||||||
|
- EmptyLatentImage / width|height → ``"width"`` / ``"height"``
|
||||||
|
- CheckpointLoaderSimple / ckpt_name → ``"checkpoint"``
|
||||||
|
- LoraLoader / lora_name → ``"lora_{node_id}"``
|
||||||
|
- Other scalars → ``"{class_slug}_{node_id}_{input_name}"``
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _slugify(text: str) -> str:
|
||||||
|
"""Convert arbitrary text to a lowercase underscore key."""
|
||||||
|
text = text.lower()
|
||||||
|
text = re.sub(r"[^a-z0-9]+", "_", text)
|
||||||
|
return text.strip("_") or "node"
|
||||||
|
|
||||||
|
|
||||||
|
def _numeric_sort_key(node_id: str):
|
||||||
|
"""Sort node IDs numerically where possible."""
|
||||||
|
try:
|
||||||
|
return (0, int(node_id))
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return (1, node_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# NodeInput dataclass
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeInput:
|
||||||
|
"""
|
||||||
|
Descriptor for a single injectable input within a ComfyUI workflow.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
node_id : str
|
||||||
|
The node's key in the workflow dict.
|
||||||
|
node_class : str
|
||||||
|
``class_type`` of the node (e.g. ``"KSampler"``).
|
||||||
|
node_title : str
|
||||||
|
Value of ``_meta.title`` for the node (may be empty).
|
||||||
|
input_name : str
|
||||||
|
The input field name within the node's ``inputs`` dict.
|
||||||
|
input_type : str
|
||||||
|
Semantic type: ``"text"``, ``"seed"``, ``"image"``, ``"integer"``,
|
||||||
|
``"float"``, ``"string"``, ``"checkpoint"``, ``"lora"``.
|
||||||
|
current_value : Any
|
||||||
|
The value currently stored in the workflow template.
|
||||||
|
label : str
|
||||||
|
Human-readable display label (``"NodeTitle / input_name"``).
|
||||||
|
key : str
|
||||||
|
Stable short key used by Discord commands and state storage.
|
||||||
|
is_common : bool
|
||||||
|
True for prompts, seeds, and image inputs (shown prominently in UI).
|
||||||
|
"""
|
||||||
|
|
||||||
|
node_id: str
|
||||||
|
node_class: str
|
||||||
|
node_title: str
|
||||||
|
input_name: str
|
||||||
|
input_type: str
|
||||||
|
current_value: Any
|
||||||
|
label: str
|
||||||
|
key: str
|
||||||
|
is_common: bool
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# WorkflowInspector
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class WorkflowInspector:
|
||||||
|
"""
|
||||||
|
Inspects and modifies ComfyUI workflow JSON.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
inspector = WorkflowInspector()
|
||||||
|
inputs = inspector.inspect(workflow)
|
||||||
|
modified, applied = inspector.inject_overrides(workflow, {"prompt": "a cat"})
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def inspect(self, workflow: dict[str, Any]) -> list[NodeInput]:
|
||||||
|
"""
|
||||||
|
Walk a workflow and return all injectable scalar inputs.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
workflow : dict
|
||||||
|
ComfyUI workflow in API format (node_id → node_dict).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[NodeInput]
|
||||||
|
All injectable inputs, common ones first, advanced sorted by key.
|
||||||
|
"""
|
||||||
|
inputs: list[NodeInput] = []
|
||||||
|
load_image_count = 0 # for unique LoadImage key assignment
|
||||||
|
|
||||||
|
for node_id, node in sorted(workflow.items(), key=lambda kv: _numeric_sort_key(kv[0])):
|
||||||
|
if not isinstance(node, dict):
|
||||||
|
continue
|
||||||
|
class_type: str = node.get("class_type", "")
|
||||||
|
title: str = node.get("_meta", {}).get("title", "") or ""
|
||||||
|
node_inputs = node.get("inputs", {})
|
||||||
|
if not isinstance(node_inputs, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for input_name, value in node_inputs.items():
|
||||||
|
# Skip node-reference inputs (they are lists like [node_id, output_slot])
|
||||||
|
if isinstance(value, list):
|
||||||
|
continue
|
||||||
|
# Skip None values — no useful type info
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
ni = self._classify_input(
|
||||||
|
node_id=node_id,
|
||||||
|
class_type=class_type,
|
||||||
|
title=title,
|
||||||
|
input_name=input_name,
|
||||||
|
value=value,
|
||||||
|
load_image_count=load_image_count,
|
||||||
|
)
|
||||||
|
if ni is not None:
|
||||||
|
if ni.input_type == "image":
|
||||||
|
load_image_count += 1
|
||||||
|
inputs.append(ni)
|
||||||
|
|
||||||
|
# Sort: common first, then advanced (both groups sorted by key for stability)
|
||||||
|
inputs.sort(key=lambda x: (0 if x.is_common else 1, x.key))
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def inject_overrides(
|
||||||
|
self,
|
||||||
|
workflow: dict[str, Any],
|
||||||
|
overrides: dict[str, Any],
|
||||||
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Deep-copy a workflow and inject override values.
|
||||||
|
|
||||||
|
Seeds that are absent from *overrides* or set to ``-1`` are
|
||||||
|
auto-randomized. All other injectable keys found via
|
||||||
|
:meth:`inspect` are updated if present in *overrides*.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
workflow : dict
|
||||||
|
The workflow template (not mutated).
|
||||||
|
overrides : dict
|
||||||
|
Mapping of ``NodeInput.key → value`` to inject.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[dict, dict]
|
||||||
|
``(modified_workflow, applied_values)`` where *applied_values*
|
||||||
|
maps each key that was actually written (including auto-generated
|
||||||
|
seeds) to the value that was used.
|
||||||
|
"""
|
||||||
|
wf = copy.deepcopy(workflow)
|
||||||
|
applied: dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Inspect the deep copy to build key → [(node_id, input_name), …]
|
||||||
|
inputs = self.inspect(wf)
|
||||||
|
|
||||||
|
# Group targets by key
|
||||||
|
key_targets: dict[str, list[tuple[str, str]]] = {}
|
||||||
|
key_itype: dict[str, str] = {}
|
||||||
|
for ni in inputs:
|
||||||
|
if ni.key not in key_targets:
|
||||||
|
key_targets[ni.key] = []
|
||||||
|
key_itype[ni.key] = ni.input_type
|
||||||
|
key_targets[ni.key].append((ni.node_id, ni.input_name))
|
||||||
|
|
||||||
|
for key, targets in key_targets.items():
|
||||||
|
itype = key_itype[key]
|
||||||
|
override_val = overrides.get(key)
|
||||||
|
|
||||||
|
if itype == "seed":
|
||||||
|
# -1 sentinel or absent → auto-randomize
|
||||||
|
if override_val is None or override_val == -1:
|
||||||
|
seed = random.randint(0, 2 ** 32 - 1)
|
||||||
|
else:
|
||||||
|
seed = int(override_val)
|
||||||
|
for node_id, input_name in targets:
|
||||||
|
wf[node_id]["inputs"][input_name] = seed
|
||||||
|
applied[key] = seed
|
||||||
|
|
||||||
|
elif override_val is not None:
|
||||||
|
for node_id, input_name in targets:
|
||||||
|
wf[node_id]["inputs"][input_name] = override_val
|
||||||
|
applied[key] = override_val
|
||||||
|
|
||||||
|
return wf, applied
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _classify_input(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
node_id: str,
|
||||||
|
class_type: str,
|
||||||
|
title: str,
|
||||||
|
input_name: str,
|
||||||
|
value: Any,
|
||||||
|
load_image_count: int,
|
||||||
|
) -> Optional[NodeInput]:
|
||||||
|
"""
|
||||||
|
Return a NodeInput for one field, or None to skip it entirely.
|
||||||
|
"""
|
||||||
|
title_display = title or class_type
|
||||||
|
|
||||||
|
# ---- CLIPTextEncode → positive/negative prompt ----
|
||||||
|
if class_type == "CLIPTextEncode" and input_name == "text":
|
||||||
|
t = title.lower()
|
||||||
|
if "positive" in t:
|
||||||
|
key = "prompt"
|
||||||
|
elif "negative" in t:
|
||||||
|
key = "negative_prompt"
|
||||||
|
else:
|
||||||
|
key = f"text_{node_id}"
|
||||||
|
return NodeInput(
|
||||||
|
node_id=node_id,
|
||||||
|
node_class=class_type,
|
||||||
|
node_title=title,
|
||||||
|
input_name=input_name,
|
||||||
|
input_type="text",
|
||||||
|
current_value=value,
|
||||||
|
label=f"{title_display} / {input_name}",
|
||||||
|
key=key,
|
||||||
|
is_common=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- LoadImage → input image ----
|
||||||
|
if class_type == "LoadImage" and input_name == "image":
|
||||||
|
if load_image_count == 0:
|
||||||
|
key = "input_image"
|
||||||
|
else:
|
||||||
|
slug = _slugify(title) if title else f"image_{node_id}"
|
||||||
|
key = slug or f"image_{node_id}"
|
||||||
|
return NodeInput(
|
||||||
|
node_id=node_id,
|
||||||
|
node_class=class_type,
|
||||||
|
node_title=title,
|
||||||
|
input_name=input_name,
|
||||||
|
input_type="image",
|
||||||
|
current_value=value,
|
||||||
|
label=f"{title_display} / {input_name}",
|
||||||
|
key=key,
|
||||||
|
is_common=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- KSampler / KSamplerAdvanced ----
|
||||||
|
if class_type in ("KSampler", "KSamplerAdvanced"):
|
||||||
|
if input_name in ("seed", "noise_seed"):
|
||||||
|
return NodeInput(
|
||||||
|
node_id=node_id,
|
||||||
|
node_class=class_type,
|
||||||
|
node_title=title,
|
||||||
|
input_name=input_name,
|
||||||
|
input_type="seed",
|
||||||
|
current_value=value,
|
||||||
|
label=f"{title_display} / seed",
|
||||||
|
key="seed",
|
||||||
|
is_common=True,
|
||||||
|
)
|
||||||
|
_ksampler_advanced = {
|
||||||
|
"steps": ("integer", "steps", False),
|
||||||
|
"cfg": ("float", "cfg", False),
|
||||||
|
"sampler_name": ("string", "sampler_name", False),
|
||||||
|
"scheduler": ("string", "scheduler", False),
|
||||||
|
"denoise": ("float", "denoise", False),
|
||||||
|
}
|
||||||
|
if input_name in _ksampler_advanced:
|
||||||
|
itype, key, is_common = _ksampler_advanced[input_name]
|
||||||
|
return NodeInput(
|
||||||
|
node_id=node_id,
|
||||||
|
node_class=class_type,
|
||||||
|
node_title=title,
|
||||||
|
input_name=input_name,
|
||||||
|
input_type=itype,
|
||||||
|
current_value=value,
|
||||||
|
label=f"{title_display} / {input_name}",
|
||||||
|
key=key,
|
||||||
|
is_common=is_common,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- EmptyLatentImage ----
|
||||||
|
if class_type == "EmptyLatentImage" and input_name in ("width", "height"):
|
||||||
|
return NodeInput(
|
||||||
|
node_id=node_id,
|
||||||
|
node_class=class_type,
|
||||||
|
node_title=title,
|
||||||
|
input_name=input_name,
|
||||||
|
input_type="integer",
|
||||||
|
current_value=value,
|
||||||
|
label=f"{title_display} / {input_name}",
|
||||||
|
key=input_name,
|
||||||
|
is_common=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- CheckpointLoaderSimple ----
|
||||||
|
if class_type == "CheckpointLoaderSimple" and input_name == "ckpt_name":
|
||||||
|
return NodeInput(
|
||||||
|
node_id=node_id,
|
||||||
|
node_class=class_type,
|
||||||
|
node_title=title,
|
||||||
|
input_name=input_name,
|
||||||
|
input_type="checkpoint",
|
||||||
|
current_value=value,
|
||||||
|
label=f"{title_display} / checkpoint",
|
||||||
|
key="checkpoint",
|
||||||
|
is_common=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- LoraLoader ----
|
||||||
|
if class_type == "LoraLoader" and input_name == "lora_name":
|
||||||
|
return NodeInput(
|
||||||
|
node_id=node_id,
|
||||||
|
node_class=class_type,
|
||||||
|
node_title=title,
|
||||||
|
input_name=input_name,
|
||||||
|
input_type="lora",
|
||||||
|
current_value=value,
|
||||||
|
label=f"{title_display} / lora",
|
||||||
|
key=f"lora_{node_id}",
|
||||||
|
is_common=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- Skip non-scalar or already-handled classes ----
|
||||||
|
_handled_classes = {
|
||||||
|
"CLIPTextEncode", "LoadImage", "KSampler", "KSamplerAdvanced",
|
||||||
|
"EmptyLatentImage", "CheckpointLoaderSimple", "LoraLoader",
|
||||||
|
}
|
||||||
|
if class_type in _handled_classes:
|
||||||
|
return None # unrecognised field for a known class → skip
|
||||||
|
|
||||||
|
# ---- Generic scalar fallback ----
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return None # booleans aren't useful override targets
|
||||||
|
if isinstance(value, int):
|
||||||
|
itype = "integer"
|
||||||
|
elif isinstance(value, float):
|
||||||
|
itype = "float"
|
||||||
|
elif isinstance(value, str):
|
||||||
|
itype = "string"
|
||||||
|
else:
|
||||||
|
return None # dicts, etc. — skip
|
||||||
|
|
||||||
|
key = f"{_slugify(class_type)}_{node_id}_{input_name}"
|
||||||
|
return NodeInput(
|
||||||
|
node_id=node_id,
|
||||||
|
node_class=class_type,
|
||||||
|
node_title=title,
|
||||||
|
input_name=input_name,
|
||||||
|
input_type=itype,
|
||||||
|
current_value=value,
|
||||||
|
label=f"{title_display} / {input_name}",
|
||||||
|
key=key,
|
||||||
|
is_common=False,
|
||||||
|
)
|
||||||
56
workflow_manager.py
Normal file
56
workflow_manager.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""
|
||||||
|
workflow_manager.py
|
||||||
|
===================
|
||||||
|
|
||||||
|
Workflow template storage for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
This module provides a WorkflowManager class responsible solely for holding
|
||||||
|
the loaded workflow template. Node inspection and injection are handled by
|
||||||
|
:class:`~workflow_inspector.WorkflowInspector`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Optional, Any
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowManager:
|
||||||
|
"""
|
||||||
|
Stores and provides access to the current workflow template.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
workflow_template : Optional[Dict[str, Any]]
|
||||||
|
An initial workflow dict. If None, no template is loaded and one
|
||||||
|
must be set via :meth:`set_workflow_template`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, workflow_template: Optional[Dict[str, Any]] = None) -> None:
|
||||||
|
self._workflow_template = workflow_template
|
||||||
|
|
||||||
|
def set_workflow_template(self, workflow: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Set or replace the workflow template.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
workflow : Dict[str, Any]
|
||||||
|
Workflow dictionary in ComfyUI API format.
|
||||||
|
"""
|
||||||
|
self._workflow_template = workflow
|
||||||
|
logger.info("Workflow template updated (%d nodes)", len(workflow))
|
||||||
|
|
||||||
|
def get_workflow_template(self) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Return the current workflow template.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Optional[Dict[str, Any]]
|
||||||
|
The loaded template, or None if none has been set.
|
||||||
|
"""
|
||||||
|
return self._workflow_template
|
||||||
252
workflow_state.py
Normal file
252
workflow_state.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""
|
||||||
|
workflow_state.py
|
||||||
|
=================
|
||||||
|
|
||||||
|
Workflow state management for the Discord ComfyUI bot.
|
||||||
|
|
||||||
|
This module provides a WorkflowStateManager class that stores runtime
|
||||||
|
overrides for workflow parameters (prompt, negative_prompt, input_image,
|
||||||
|
seed, steps, cfg, …) in a generic key-value dict. Any NodeInput.key
|
||||||
|
produced by WorkflowInspector can be set as an override.
|
||||||
|
|
||||||
|
The old fixed fields (prompt, negative_prompt, input_image, seed) are
|
||||||
|
preserved as convenience wrappers for backward compatibility with existing
|
||||||
|
Discord commands and status_monitor.
|
||||||
|
|
||||||
|
State file format (current-workflow-changes.json)::
|
||||||
|
|
||||||
|
{
|
||||||
|
"overrides": {
|
||||||
|
"prompt": "a beautiful landscape",
|
||||||
|
"seed": 42,
|
||||||
|
...
|
||||||
|
},
|
||||||
|
"last_workflow_file": "my_workflow.json"
|
||||||
|
}
|
||||||
|
|
||||||
|
The old flat-key format is migrated automatically on first load.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowStateManager:
|
||||||
|
"""
|
||||||
|
Manages runtime workflow overrides in memory with optional persistence.
|
||||||
|
|
||||||
|
Override keys correspond to ``NodeInput.key`` values discovered by
|
||||||
|
:class:`~workflow_inspector.WorkflowInspector`. Common well-known keys
|
||||||
|
are ``"prompt"``, ``"negative_prompt"``, ``"input_image"``, ``"seed"``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
state_file : Optional[str]
|
||||||
|
Path to a JSON file for persisting overrides. Loaded on init if the
|
||||||
|
file exists; auto-saved on every change.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, state_file: Optional[str] = None) -> None:
|
||||||
|
self._overrides: Dict[str, Any] = {}
|
||||||
|
self._last_workflow_file: Optional[str] = None
|
||||||
|
self._state_file = state_file
|
||||||
|
|
||||||
|
if self._state_file:
|
||||||
|
self._load_from_file()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Persistence
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _load_from_file(self) -> None:
|
||||||
|
"""Load state from the configured JSON file if it exists."""
|
||||||
|
if not self._state_file:
|
||||||
|
return
|
||||||
|
state_path = Path(self._state_file)
|
||||||
|
if not state_path.exists():
|
||||||
|
logger.debug("State file %s not found, using empty state", self._state_file)
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
with open(state_path, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
# New format: {"overrides": {...}, "last_workflow_file": ...}
|
||||||
|
if "overrides" in data and isinstance(data["overrides"], dict):
|
||||||
|
self._overrides = data["overrides"]
|
||||||
|
self._last_workflow_file = data.get("last_workflow_file")
|
||||||
|
else:
|
||||||
|
# Migrate old flat format: {"prompt": ..., "negative_prompt": ..., ...}
|
||||||
|
self._overrides = {
|
||||||
|
k: v for k, v in data.items()
|
||||||
|
if v is not None and k not in ("last_workflow_file",)
|
||||||
|
}
|
||||||
|
self._last_workflow_file = data.get("last_workflow_file")
|
||||||
|
logger.info("Loaded workflow state from %s", self._state_file)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to load state from %s: %s", self._state_file, exc)
|
||||||
|
|
||||||
|
def save_to_file(self) -> None:
|
||||||
|
"""
|
||||||
|
Persist current overrides and last_workflow_file to the state JSON file.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
RuntimeError
|
||||||
|
If no state file was configured.
|
||||||
|
"""
|
||||||
|
if not self._state_file:
|
||||||
|
raise RuntimeError("Cannot save state: no state file configured")
|
||||||
|
try:
|
||||||
|
data = {
|
||||||
|
"overrides": self._overrides,
|
||||||
|
"last_workflow_file": self._last_workflow_file,
|
||||||
|
}
|
||||||
|
with open(self._state_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(data, f, indent=4)
|
||||||
|
logger.debug("Saved workflow state to %s", self._state_file)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to save state to %s: %s", self._state_file, exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _autosave(self) -> None:
|
||||||
|
"""Save to file silently if a state file is configured."""
|
||||||
|
if self._state_file:
|
||||||
|
try:
|
||||||
|
self.save_to_file()
|
||||||
|
except Exception:
|
||||||
|
pass # already logged inside save_to_file
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Generic override API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_overrides(self) -> Dict[str, Any]:
|
||||||
|
"""Return a shallow copy of all current overrides."""
|
||||||
|
return self._overrides.copy()
|
||||||
|
|
||||||
|
def set_override(self, key: str, value: Any) -> None:
|
||||||
|
"""Set a single override key and auto-save."""
|
||||||
|
self._overrides[key] = value
|
||||||
|
self._autosave()
|
||||||
|
|
||||||
|
def delete_override(self, key: str) -> None:
|
||||||
|
"""Remove a single override key (no-op if absent) and auto-save."""
|
||||||
|
self._overrides.pop(key, None)
|
||||||
|
self._autosave()
|
||||||
|
|
||||||
|
def clear_overrides(self) -> None:
|
||||||
|
"""Remove all override keys and auto-save."""
|
||||||
|
self._overrides = {}
|
||||||
|
self._autosave()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Last-workflow-file tracking
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_last_workflow_file(self) -> Optional[str]:
|
||||||
|
"""Return the last loaded workflow filename (for auto-load on restart)."""
|
||||||
|
return self._last_workflow_file
|
||||||
|
|
||||||
|
def set_last_workflow_file(self, filename: Optional[str]) -> None:
|
||||||
|
"""Record the last loaded workflow filename and auto-save."""
|
||||||
|
self._last_workflow_file = filename
|
||||||
|
self._autosave()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Backward-compat: old get_changes / set_changes API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_changes(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Alias for :meth:`get_overrides` retained for backward compatibility.
|
||||||
|
|
||||||
|
Returns a dict that always has ``prompt``, ``negative_prompt``,
|
||||||
|
``input_image``, and ``seed`` keys (value is ``None`` when unset)
|
||||||
|
so existing callers that rely on these specific keys still work.
|
||||||
|
"""
|
||||||
|
base: Dict[str, Any] = {
|
||||||
|
"prompt": None,
|
||||||
|
"negative_prompt": None,
|
||||||
|
"input_image": None,
|
||||||
|
"seed": None,
|
||||||
|
}
|
||||||
|
base.update(self._overrides)
|
||||||
|
return base
|
||||||
|
|
||||||
|
def set_changes(self, changes: Dict[str, Any], merge: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Set multiple overrides at once.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
changes : dict
|
||||||
|
Key-value pairs to apply.
|
||||||
|
merge : bool
|
||||||
|
If True (default), merge with existing overrides.
|
||||||
|
If False, replace all overrides with these values (``None``
|
||||||
|
values are excluded).
|
||||||
|
"""
|
||||||
|
if merge:
|
||||||
|
for k, v in changes.items():
|
||||||
|
if v is not None:
|
||||||
|
self._overrides[k] = v
|
||||||
|
else:
|
||||||
|
self._overrides = {k: v for k, v in changes.items() if v is not None}
|
||||||
|
self._autosave()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Convenience setters / getters for well-known keys
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def set_prompt(self, prompt: str) -> None:
|
||||||
|
"""Set the positive prompt override."""
|
||||||
|
self.set_override("prompt", prompt)
|
||||||
|
|
||||||
|
def get_prompt(self) -> Optional[str]:
|
||||||
|
"""Return the positive prompt override, or None if not set."""
|
||||||
|
return self._overrides.get("prompt")
|
||||||
|
|
||||||
|
def set_negative_prompt(self, negative_prompt: str) -> None:
|
||||||
|
"""Set the negative prompt override."""
|
||||||
|
self.set_override("negative_prompt", negative_prompt)
|
||||||
|
|
||||||
|
def get_negative_prompt(self) -> Optional[str]:
|
||||||
|
"""Return the negative prompt override, or None if not set."""
|
||||||
|
return self._overrides.get("negative_prompt")
|
||||||
|
|
||||||
|
def set_input_image(self, input_image: str) -> None:
|
||||||
|
"""Set the input image override (filename)."""
|
||||||
|
self.set_override("input_image", input_image)
|
||||||
|
|
||||||
|
def get_input_image(self) -> Optional[str]:
|
||||||
|
"""Return the input image override, or None if not set."""
|
||||||
|
return self._overrides.get("input_image")
|
||||||
|
|
||||||
|
def set_seed(self, seed: int) -> None:
|
||||||
|
"""Pin a specific seed for deterministic generation."""
|
||||||
|
self.set_override("seed", seed)
|
||||||
|
|
||||||
|
def get_seed(self) -> Optional[int]:
|
||||||
|
"""Return the pinned seed, or None if randomising each run."""
|
||||||
|
return self._overrides.get("seed")
|
||||||
|
|
||||||
|
def clear_seed(self) -> None:
|
||||||
|
"""Clear the pinned seed, reverting to random generation."""
|
||||||
|
self.delete_override("seed")
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Reset all overrides (alias for :meth:`clear_overrides`)."""
|
||||||
|
self.clear_overrides()
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"WorkflowStateManager("
|
||||||
|
f"overrides={self._overrides!r}, "
|
||||||
|
f"last_workflow_file={self._last_workflow_file!r})"
|
||||||
|
)
|
||||||
0
workflows/.gitkeep
Normal file
0
workflows/.gitkeep
Normal file
Reference in New Issue
Block a user