Initial commit: waitlist microservice
This commit is contained in:
76
app/rate_limit.py
Normal file
76
app/rate_limit.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
IP-based sliding-window rate limiter.
|
||||
|
||||
Cloudflare-aware: uses CF-Connecting-IP → X-Forwarded-For → client.host
|
||||
to identify the real client IP.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
"""Extract real client IP behind Cloudflare / reverse proxy."""
|
||||
# Cloudflare always sets this when proxying
|
||||
cf_ip = request.headers.get("cf-connecting-ip")
|
||||
if cf_ip:
|
||||
return cf_ip.strip()
|
||||
|
||||
# Fallback: first entry in X-Forwarded-For (set by most reverse proxies)
|
||||
xff = request.headers.get("x-forwarded-for")
|
||||
if xff:
|
||||
return xff.split(",")[0].strip()
|
||||
|
||||
# Last resort: direct connection IP
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
# Module-level hits store so tests can clear it
|
||||
_hits_store: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
|
||||
class RateLimiter(BaseHTTPMiddleware):
|
||||
"""
|
||||
Sliding-window rate limiter keyed on client IP.
|
||||
|
||||
Only applies to POST /api/v1/waitlist.
|
||||
Returns 429 with Retry-After header when exceeded.
|
||||
"""
|
||||
|
||||
def __init__(self, app, per_minute: int = settings.RATE_LIMIT_PER_MINUTE):
|
||||
super().__init__(app)
|
||||
self.per_minute = per_minute
|
||||
self.window = 60 # seconds
|
||||
self._hits = _hits_store
|
||||
self._lock = Lock()
|
||||
|
||||
def _prune(self, ip: str, now: float) -> None:
|
||||
cutoff = now - self.window
|
||||
self._hits[ip] = [t for t in self._hits[ip] if t > cutoff]
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
# Only rate-limit the waitlist POST endpoint
|
||||
if request.method != "POST" or request.url.path != "/api/v1/waitlist":
|
||||
return await call_next(request)
|
||||
|
||||
ip = _get_client_ip(request)
|
||||
now = time.monotonic()
|
||||
|
||||
with self._lock:
|
||||
self._prune(ip, now)
|
||||
if len(self._hits[ip]) >= self.per_minute:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Too many requests. Please try again later."},
|
||||
headers={"Retry-After": str(self.window)},
|
||||
)
|
||||
self._hits[ip].append(now)
|
||||
|
||||
return await call_next(request)
|
||||
Reference in New Issue
Block a user