77 lines
2.4 KiB
Python
77 lines
2.4 KiB
Python
"""
|
|
IP-based sliding-window rate limiter.
|
|
|
|
Cloudflare-aware: uses CF-Connecting-IP → X-Forwarded-For → client.host
|
|
to identify the real client IP.
|
|
"""
|
|
|
|
import math
import time
from collections import defaultdict
from threading import Lock

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from app.config import settings
|
|
|
|
|
|
def _get_client_ip(request: Request) -> str:
    """Extract real client IP behind Cloudflare / reverse proxy.

    Resolution order:
      1. the ``CF-Connecting-IP`` header,
      2. the leftmost entry of ``X-Forwarded-For``,
      3. the socket peer address (``request.client.host``),
      4. the literal ``"unknown"`` when no peer address is available.

    Blank header values, and a blank leftmost X-Forwarded-For entry
    (e.g. ``", 1.2.3.4"``), are treated as absent. They fall through to the
    next source instead of producing an empty-string key that would lump
    unrelated clients into one bucket.

    NOTE(review): both headers are read unconditionally, so a client that can
    reach the app directly can set them and choose its own rate-limit key.
    The leftmost X-Forwarded-For entry is client-supplied even behind a proxy
    that appends to the header. This is only safe if the origin is reachable
    solely through the trusted proxy -- confirm the deployment enforces that.
    """
    # Cloudflare sets this when proxying the request to the origin.
    cf_ip = (request.headers.get("cf-connecting-ip") or "").strip()
    if cf_ip:
        return cf_ip

    # Fallback: first entry in X-Forwarded-For (set by most reverse proxies).
    xff = request.headers.get("x-forwarded-for") or ""
    first_hop = xff.split(",", 1)[0].strip()
    if first_hop:
        return first_hop

    # Last resort: direct connection IP.
    return request.client.host if request.client else "unknown"
|
|
|
|
|
|
# Per-IP request timestamps (time.monotonic() values) shared by every
# RateLimiter instance. It lives at module scope rather than on the
# middleware object so tests can reset it between cases with .clear().
_hits_store: defaultdict[str, list[float]] = defaultdict(list)
|
|
|
|
|
|
class RateLimiter(BaseHTTPMiddleware):
    """
    Sliding-window rate limiter keyed on client IP.

    Only applies to POST /api/v1/waitlist.
    Returns 429 with a Retry-After header when exceeded. The header value is
    the number of seconds until the oldest counted request leaves the window.

    Timestamps are kept in the module-level ``_hits_store``, which is shared
    by every instance, so the lock guarding it is shared as well. IPs with no
    requests left in the window are evicted, with at most one full sweep per
    window. This keeps the store from growing without bound as new client
    IPs appear.
    """

    # One lock for the shared module-level store. A per-instance lock would
    # not serialize access between two middleware instances.
    _store_lock = Lock()

    def __init__(self, app, per_minute: int = settings.RATE_LIMIT_PER_MINUTE):
        super().__init__(app)
        self.per_minute = per_minute
        self.window = 60  # seconds
        self._hits = _hits_store
        self._lock = self._store_lock
        # time.monotonic() of the last full eviction pass (see _sweep).
        self._last_sweep = time.monotonic()

    def _prune(self, ip: str, now: float) -> None:
        """Keep only *ip*'s timestamps newer than ``now - window``.

        The key is removed entirely when nothing remains, so idle IPs do not
        leave empty lists behind in the shared store.
        """
        cutoff = now - self.window
        recent = [t for t in self._hits.get(ip, ()) if t > cutoff]
        if recent:
            self._hits[ip] = recent
        else:
            self._hits.pop(ip, None)

    def _sweep(self, now: float) -> None:
        """Prune every tracked IP, at most once per window."""
        if now - self._last_sweep < self.window:
            return
        self._last_sweep = now
        for ip in list(self._hits):  # copy the keys: _prune may delete them
            self._prune(ip, now)

    def _retry_after(self, hits: list[float], now: float) -> int:
        """Return whole seconds until the oldest of *hits* leaves the window."""
        if not hits:
            # Nothing will expire (e.g. per_minute <= 0): advise a full window.
            return self.window
        return max(1, math.ceil(min(hits) + self.window - now))

    async def dispatch(self, request: Request, call_next) -> Response:
        """Count the request against its client IP, or reject it with 429."""
        # Only rate-limit the waitlist POST endpoint
        if request.method != "POST" or request.url.path != "/api/v1/waitlist":
            return await call_next(request)

        ip = _get_client_ip(request)
        now = time.monotonic()

        with self._lock:
            self._sweep(now)
            self._prune(ip, now)
            hits = self._hits.get(ip, [])
            if len(hits) >= self.per_minute:
                retry_after = self._retry_after(hits, now)
            else:
                retry_after = None
                hits.append(now)
                self._hits[ip] = hits

        if retry_after is not None:
            return JSONResponse(
                status_code=429,
                content={"detail": "Too many requests. Please try again later."},
                headers={"Retry-After": str(retry_after)},
            )

        return await call_next(request)
|