"""
IP-based sliding-window rate limiter.

Cloudflare-aware: uses CF-Connecting-IP → X-Forwarded-For → client.host
to identify the real client IP.
"""

import math
import time
from collections import defaultdict
from threading import Lock

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from app.config import settings


def _get_client_ip(request: Request) -> str:
    """Extract the client IP used as the rate-limit key.

    Sources are tried in order: the ``CF-Connecting-IP`` header, the first
    entry of ``X-Forwarded-For``, then the direct connection address. A
    source that is missing or blank is skipped so that malformed headers
    (e.g. ``X-Forwarded-For: , 1.2.3.4``) do not collapse unrelated clients
    into a single empty-string bucket.

    Returns:
        The IP string, or ``"unknown"`` when no source yields a value.
    """
    # NOTE(review): both headers are taken at face value. If the origin can
    # be reached without going through the trusted proxy, a client can forge
    # them to dodge the limit — confirm the deployment blocks direct access.
    cf_ip = request.headers.get("cf-connecting-ip", "").strip()
    if cf_ip:
        return cf_ip

    # Fallback: first entry in X-Forwarded-For.
    forwarded = request.headers.get("x-forwarded-for", "")
    first_hop = forwarded.split(",")[0].strip()
    if first_hop:
        return first_hop

    # Last resort: direct connection IP.
    return request.client.host if request.client else "unknown"


# Module-level hits store so tests can clear it.
_hits_store: dict[str, list[float]] = defaultdict(list)

# The store is shared by every RateLimiter instance, so the lock guarding it
# has to be shared as well; a per-instance lock would not exclude the others.
_hits_lock = Lock()


class RateLimiter(BaseHTTPMiddleware):
    """
    Sliding-window rate limiter keyed on client IP.

    Only applies to POST /api/v1/waitlist.
    Returns 429 with Retry-After header when exceeded.
    """

    def __init__(self, app, per_minute: int = settings.RATE_LIMIT_PER_MINUTE):
        """Wrap *app*, allowing at most *per_minute* hits per IP per window."""
        super().__init__(app)
        self.per_minute = per_minute
        self.window = 60  # seconds
        self._hits = _hits_store
        self._lock = _hits_lock
        # Monotonic time of the last full-store sweep (see ``_sweep``).
        self._last_sweep = time.monotonic()

    def _prune(self, ip: str, now: float) -> None:
        """Discard timestamps for *ip* that have aged out of the window.

        The key is removed entirely once no timestamps remain, so IPs that
        stop sending requests do not leave empty lists behind.
        """
        cutoff = now - self.window
        recent = [t for t in self._hits.get(ip, ()) if t > cutoff]
        if recent:
            self._hits[ip] = recent
        else:
            self._hits.pop(ip, None)

    def _sweep(self, now: float) -> None:
        """Prune every tracked IP, dropping those with no recent hits."""
        # Iterate over a snapshot of the keys: ``_prune`` deletes entries.
        for ip in list(self._hits):
            self._prune(ip, now)

    async def dispatch(self, request: Request, call_next) -> Response:
        """Pass the request through, or answer 429 if the IP is over limit."""
        # Only rate-limit the waitlist POST endpoint.
        if request.method != "POST" or request.url.path != "/api/v1/waitlist":
            return await call_next(request)

        ip = _get_client_ip(request)
        now = time.monotonic()
        retry_after = None

        with self._lock:
            if now - self._last_sweep >= self.window:
                # Once per window, clear out IPs that never came back;
                # per-request pruning alone only touches the current IP.
                self._sweep(now)
                self._last_sweep = now
            else:
                self._prune(ip, now)

            hits = self._hits.get(ip, [])
            if len(hits) >= self.per_minute:
                if hits:
                    # Timestamps are appended in increasing order, so the
                    # first one is the next to leave the window.
                    retry_after = max(
                        1, math.ceil(hits[0] + self.window - now)
                    )
                else:
                    # Only reachable when per_minute <= 0.
                    retry_after = self.window
            else:
                self._hits.setdefault(ip, []).append(now)

        if retry_after is not None:
            return JSONResponse(
                status_code=429,
                content={"detail": "Too many requests. Please try again later."},
                headers={"Retry-After": str(retry_after)},
            )

        return await call_next(request)