Files
comfy-discord-web/web/login_guard.py
Khoa (Revenovich) Tran Gia 1ed3c9ec4b Initial commit — ComfyUI Discord bot + web UI
Full source for the-third-rev: Discord bot (discord.py), FastAPI web UI
(React/TS/Vite/Tailwind), ComfyUI integration, generation history DB,
preset manager, workflow inspector, and all supporting modules.

Excluded from tracking: .env, invite_tokens.json, *.db (SQLite),
current-workflow-changes.json, user_settings/, presets/, logs/,
web-static/ (build output), frontend/node_modules/.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 09:55:48 +07:00

147 lines
4.9 KiB
Python

"""
web/login_guard.py
==================
IP-based brute-force protection for login endpoints.
Tracks failed login attempts per IP in a rolling time window and issues a
temporary ban when the threshold is exceeded. Uses only stdlib — no new
pip packages required.
Usage
-----
::

    from web.login_guard import get_guard, get_real_ip

    @router.post("/login")
    async def login(request: Request, body: LoginRequest, response: Response):
        ip = get_real_ip(request)
        get_guard().check(ip)  # raises 429 if locked out
        ...
        if failure:
            get_guard().record_failure(ip)
            raise HTTPException(401, ...)
        get_guard().record_success(ip)
        ...
"""
from __future__ import annotations
import logging
import time
from collections import defaultdict
from typing import Dict, List
from fastapi import HTTPException, Request, status
# Module logger; failed attempts, bans and blocked requests are logged at WARNING.
logger = logging.getLogger(__name__)
def get_real_ip(request: Request) -> str:
    """Best-effort client IP for *request*, preferring proxy-supplied headers.

    Headers are consulted in this order, and the first value that is
    non-blank after stripping whitespace wins:

    1. ``CF-Connecting-IP`` (set by Cloudflare)
    2. ``X-Real-IP`` (set by nginx/traefik)

    If neither header yields a value, the function falls back to
    ``request.client.host``. When the connection has no client address, it
    returns the literal string ``"unknown"``.

    Security note: header values are trusted as-is, with no validation. This
    is only safe when every request passes through a proxy that overwrites
    these headers. A client that can reach the app directly can set them to
    arbitrary values, for example to sidestep per-IP lockouts.
    """
    for header_name in ("CF-Connecting-IP", "X-Real-IP"):
        value = request.headers.get(header_name, "").strip()
        if value:
            return value
    client = request.client
    return client.host if client else "unknown"
class BruteForceGuard:
    """Rolling-window failure counter with automatic IP bans.

    An IP that records ``MAX_FAILURES`` failed logins within ``WINDOW_SECS``
    seconds is banned for ``BAN_SECS`` seconds. While the ban is active,
    :meth:`check` rejects the IP with HTTP 429.

    All state is in-process memory. A restart clears all bans and counters.
    This is acceptable because the downtime of a restart already imposes a
    natural backoff on an attacker.

    Stale entries are discarded so that IPs which fail once and never return
    do not accumulate forever. Stale entries are failures that have aged out
    of the window, and bans that have expired. They are dropped lazily per
    IP, and by a full sweep that runs at most once per ``WINDOW_SECS``.
    """

    WINDOW_SECS = 600   # rolling window: 10 minutes
    MAX_FAILURES = 10   # failures within the window that trigger a ban
    BAN_SECS = 3600     # ban duration: 1 hour

    def __init__(self) -> None:
        # ip → failure timestamps (epoch floats). After pruning, only IPs with
        # at least one in-window failure keep an entry.
        self._failures: Dict[str, List[float]] = defaultdict(list)
        # ip → ban expiry timestamp (epoch float)
        self._ban_until: Dict[str, float] = {}
        # Epoch time of the last full sweep of stale entries across all IPs.
        self._last_sweep: float = time.time()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def check(self, ip: str) -> None:
        """Raise HTTP 429 if the IP is banned or has exceeded the failure threshold.

        Call this *before* doing any credential work so the lockout is
        evaluated even when the request body is malformed.

        Raises:
            HTTPException: 429 Too Many Requests when *ip* is locked out.
        """
        now = time.time()
        self._maybe_sweep(now)

        # Active ban?
        ban_expiry = self._ban_until.get(ip)
        if ban_expiry is not None:
            if ban_expiry > now:
                logger.warning("login_guard: blocked request from banned ip=%s (ban expires in %.0fs)", ip, ban_expiry - now)
                raise self._too_many_attempts()
            # The ban has lapsed. Forget it instead of keeping a dead entry.
            del self._ban_until[ip]

        # Failure count within the rolling window. _recent_failures() does not
        # create an empty defaultdict entry for IPs without failures.
        if len(self._recent_failures(ip, now)) >= self.MAX_FAILURES:
            # record_failure() normally bans at this point already; this is a
            # safety net for a threshold reached without an active ban.
            self._ban(ip, now)
            raise self._too_many_attempts()

    def record_failure(self, ip: str) -> None:
        """Record a failed login attempt for the given IP.

        The IP is banned as soon as it has ``MAX_FAILURES`` failures inside
        the rolling window.
        """
        now = time.time()
        self._maybe_sweep(now)
        recent = self._recent_failures(ip, now)
        recent.append(now)
        self._failures[ip] = recent
        count = len(recent)
        logger.warning("login_guard: failure #%d from ip=%s", count, ip)
        if count >= self.MAX_FAILURES:
            self._ban(ip, now)

    def record_success(self, ip: str) -> None:
        """Clear failure history and any active ban for the given IP."""
        self._failures.pop(ip, None)
        self._ban_until.pop(ip, None)

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    @staticmethod
    def _too_many_attempts() -> HTTPException:
        """Build the 429 error raised by every lockout path."""
        return HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Too many failed attempts. Try again later.",
        )

    def _ban(self, ip: str, now: float) -> None:
        """Ban *ip* for ``BAN_SECS`` seconds starting at *now*."""
        self._ban_until[ip] = now + self.BAN_SECS
        logger.warning(
            "login_guard: threshold reached, banning ip=%s for %ds",
            ip, self.BAN_SECS,
        )

    def _recent_failures(self, ip: str, now: float) -> List[float]:
        """Return *ip*'s in-window failures and store only that pruned list.

        If no failure is still inside the window, the IP's entry is removed
        entirely. In that case the returned list is empty and is not stored.
        """
        cutoff = now - self.WINDOW_SECS
        # .get() avoids the defaultdict inserting an empty list as a side effect.
        recent = [t for t in self._failures.get(ip, ()) if t > cutoff]
        if recent:
            self._failures[ip] = recent
        else:
            self._failures.pop(ip, None)
        return recent

    def _maybe_sweep(self, now: float) -> None:
        """Drop stale failures and expired bans for all IPs, at most once per window."""
        elapsed = now - self._last_sweep
        # A negative elapsed value means the wall clock jumped backwards; sweep
        # anyway so that _last_sweep resynchronises with the clock.
        if 0 <= elapsed < self.WINDOW_SECS:
            return
        self._last_sweep = now
        cutoff = now - self.WINDOW_SECS
        stale_ips = [
            ip for ip, stamps in self._failures.items()
            if max(stamps, default=cutoff) <= cutoff
        ]
        for ip in stale_ips:
            del self._failures[ip]
        expired_ips = [ip for ip, expiry in self._ban_until.items() if expiry <= now]
        for ip in expired_ips:
            del self._ban_until[ip]
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
_guard: BruteForceGuard | None = None


def get_guard() -> BruteForceGuard:
    """Return the process-wide :class:`BruteForceGuard`, creating it lazily.

    The first call instantiates the guard. Every later call returns that same
    instance, so all callers share one set of failure counters and bans.
    """
    global _guard
    instance = _guard
    if instance is None:
        instance = _guard = BruteForceGuard()
    return instance