""" web/login_guard.py ================== IP-based brute-force protection for login endpoints. Tracks failed login attempts per IP in a rolling time window and issues a temporary ban when the threshold is exceeded. Uses only stdlib — no new pip packages required. Usage ----- from web.login_guard import get_guard, get_real_ip @router.post("/login") async def login(request: Request, body: LoginRequest, response: Response): ip = get_real_ip(request) get_guard().check(ip) # raises 429 if locked out ... if failure: get_guard().record_failure(ip) raise HTTPException(401, ...) get_guard().record_success(ip) ... """ from __future__ import annotations import logging import time from collections import defaultdict from typing import Dict, List from fastapi import HTTPException, Request, status logger = logging.getLogger(__name__) def get_real_ip(request: Request) -> str: """Return the real client IP, honouring Cloudflare and common proxy headers. Priority: 1. ``CF-Connecting-IP`` (set by Cloudflare) 2. ``X-Real-IP`` (set by nginx/traefik) 3. ``request.client.host`` (direct connection fallback) """ cf_ip = request.headers.get("CF-Connecting-IP", "").strip() if cf_ip: return cf_ip real_ip = request.headers.get("X-Real-IP", "").strip() if real_ip: return real_ip return request.client.host if request.client else "unknown" class BruteForceGuard: """Rolling-window failure counter with automatic IP bans. All state is in-process memory. A restart clears all bans and counters, which is acceptable — a brief restart already provides a natural backoff for a legitimate attacker. """ WINDOW_SECS = 600 # rolling window: 10 minutes MAX_FAILURES = 10 # max failures before ban BAN_SECS = 3600 # ban duration: 1 hour def __init__(self) -> None: # ip → list of failure timestamps (epoch floats) self._failures: Dict[str, List[float]] = defaultdict(list) # ip → ban expiry timestamp self._ban_until: Dict[str, float] = {} # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def check(self, ip: str) -> None: """Raise HTTP 429 if the IP is banned or has exceeded the failure threshold. Call this *before* doing any credential work so the lockout is evaluated even when the request body is malformed. """ now = time.time() # Active ban? ban_expiry = self._ban_until.get(ip, 0) if ban_expiry > now: logger.warning("login_guard: blocked request from banned ip=%s (ban expires in %.0fs)", ip, ban_expiry - now) raise HTTPException( status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many failed attempts. Try again later.", ) # Failure count within rolling window cutoff = now - self.WINDOW_SECS recent = [t for t in self._failures[ip] if t > cutoff] self._failures[ip] = recent # prune stale entries while we're here if len(recent) >= self.MAX_FAILURES: # Threshold just reached — apply ban now self._ban_until[ip] = now + self.BAN_SECS logger.warning( "login_guard: threshold reached, banning ip=%s for %ds", ip, self.BAN_SECS, ) raise HTTPException( status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many failed attempts. Try again later.", ) def record_failure(self, ip: str) -> None: """Record a failed login attempt for the given IP.""" now = time.time() cutoff = now - self.WINDOW_SECS recent = [t for t in self._failures[ip] if t > cutoff] recent.append(now) self._failures[ip] = recent count = len(recent) logger.warning("login_guard: failure #%d from ip=%s", count, ip) if count >= self.MAX_FAILURES: self._ban_until[ip] = now + self.BAN_SECS logger.warning( "login_guard: threshold reached, banning ip=%s for %ds", ip, self.BAN_SECS, ) def record_success(self, ip: str) -> None: """Clear failure history and any active ban for the given IP.""" self._failures.pop(ip, None) self._ban_until.pop(ip, None) # --------------------------------------------------------------------------- # Module-level singleton # --------------------------------------------------------------------------- _guard: BruteForceGuard | None = None def get_guard() -> BruteForceGuard: """Return the shared BruteForceGuard singleton (created on first call).""" global _guard if _guard is None: _guard = BruteForceGuard() return _guard