"""Tier-aware rate limiting middleware. Uses a per-user sliding-window counter (in-process, no Redis required). The ``slowapi`` Limiter is also exported for optional route-level decoration. Limits (requests per minute): - free: 20 - pro: 60 - power: 120 - team: 200 Exempt paths bypass the limiter entirely: - POST /api/v1/auth/register - POST /api/v1/auth/login - POST /api/v1/billing/webhook - GET /api/v1/health """ from __future__ import annotations import json import time from collections import defaultdict from fastapi import Request, Response from jose import JWTError, jwt from slowapi import Limiter from slowapi.util import get_remote_address from starlette.middleware.base import BaseHTTPMiddleware from starlette.types import ASGIApp from app.config.settings import settings _TIER_LIMITS: dict[str, int] = { "free": 20, "pro": 60, "power": 120, "team": 200, } _EXEMPT_PATHS: frozenset[str] = frozenset( { "/api/v1/auth/register", "/api/v1/auth/login", "/api/v1/billing/webhook", "/api/v1/health", } ) def _get_user_id_from_jwt(request: Request) -> str: """Key function for the slowapi Limiter: returns JWT sub or remote IP.""" auth = request.headers.get("Authorization", "") token = auth.removeprefix("Bearer ").strip() if not token: return get_remote_address(request) try: payload = jwt.decode( token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] ) return payload.get("sub") or get_remote_address(request) except JWTError: return get_remote_address(request) # Exported Limiter instance — available for optional route-level decoration. limiter = Limiter(key_func=_get_user_id_from_jwt) class TierRateLimitMiddleware(BaseHTTPMiddleware): """Sliding-window rate limiter applied globally across all non-exempt routes. Each authenticated user gets their own 60-second window sized by tier. Unauthenticated requests pass through (the auth dependency will reject them with 401 before the route handler runs). """ def __init__(self, app: ASGIApp) -> None: super().__init__(app) # user_id → list of request timestamps (float, seconds since epoch) self._window: dict[str, list[float]] = defaultdict(list) async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override] if request.url.path in _EXEMPT_PATHS: return await call_next(request) # Extract JWT claims — if no valid token, pass through for auth dep to handle. auth = request.headers.get("Authorization", "") token = auth.removeprefix("Bearer ").strip() if not token: return await call_next(request) try: payload = jwt.decode( token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] ) user_id: str = payload.get("sub") or get_remote_address(request) tier: str = payload.get("tier", "free") except JWTError: return await call_next(request) limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"]) now = time.monotonic() window_start = now - 60.0 # Slide the window: discard timestamps older than 60 seconds. timestamps = [t for t in self._window[user_id] if t > window_start] if len(timestamps) >= limit: retry_after = max(1, int(60 - (now - min(timestamps)))) return Response( content=json.dumps( { "detail": ( f"Rate limit exceeded ({limit} req/min for {tier} tier). " f"Retry in {retry_after}s." ) } ), status_code=429, headers={ "Retry-After": str(retry_after), "Content-Type": "application/json", }, ) timestamps.append(now) self._window[user_id] = timestamps return await call_next(request)