130 lines
4.1 KiB
Python
130 lines
4.1 KiB
Python
"""Tier-aware rate limiting middleware.
|
|
|
|
Uses a per-user sliding-window counter held in process memory (no Redis
required), so each worker process enforces its limits independently.
The ``slowapi`` Limiter is also exported for optional route-level decoration.

Limits (requests per rolling 60-second window):

- free: 20
- pro: 60
- power: 120
- team: 200

Exempt paths bypass the limiter entirely. Matching is on the path alone, so
every HTTP method on these paths is exempt:

- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
- GET /api/v1/health
"""
|
|
|
|
from __future__ import annotations

import json
import math
import time
from collections import defaultdict

from fastapi import Request, Response
from jose import JWTError, jwt
from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from app.config.settings import settings
|
|
|
|
_TIER_LIMITS: dict[str, int] = {
|
|
"free": 20,
|
|
"pro": 60,
|
|
"power": 120,
|
|
"team": 200,
|
|
}
|
|
|
|
_EXEMPT_PATHS: frozenset[str] = frozenset(
|
|
{
|
|
"/api/v1/auth/register",
|
|
"/api/v1/auth/login",
|
|
"/api/v1/billing/webhook",
|
|
"/api/v1/health",
|
|
}
|
|
)
|
|
|
|
|
|
def _get_user_id_from_jwt(request: Request) -> str:
    """Return the rate-limit key for *request*: the JWT ``sub`` claim or the client IP.

    Used as the ``key_func`` of the exported slowapi ``Limiter``.

    The token is read from the ``Authorization`` header. A leading ``Bearer``
    scheme is stripped case-insensitively (auth schemes are case-insensitive
    per RFC 7235). A header value with no scheme is treated as a raw token.
    The remote address is returned instead when:

    - the header is empty;
    - the token fails to decode/verify (``JWTError``);
    - the payload has no ``sub`` claim, or an empty one.
    """
    header = request.headers.get("Authorization", "")
    scheme, sep, credentials = header.partition(" ")
    if sep and scheme.lower() == "bearer":
        token = credentials.strip()
    else:
        # No recognised scheme: keep the previous behavior of trying the
        # whole (stripped) header value as the token.
        token = header.strip()

    if not token:
        return get_remote_address(request)

    try:
        payload = jwt.decode(
            token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
    except JWTError:
        return get_remote_address(request)

    return payload.get("sub") or get_remote_address(request)
|
|
|
|
|
|
# Module-level slowapi Limiter, exported so individual routes can opt in to
# extra limits via decoration. Its key is the JWT ``sub`` claim, falling back
# to the client IP (see _get_user_id_from_jwt). It is constructed without any
# default limits of its own.
limiter: Limiter = Limiter(key_func=_get_user_id_from_jwt)
|
|
|
|
|
|
class TierRateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window rate limiter applied globally across all non-exempt routes.

    Each authenticated caller gets its own rolling 60-second window:

    - The key is the JWT ``sub`` claim, or the client address when ``sub`` is
      missing or empty.
    - The window size comes from ``_TIER_LIMITS`` via the ``tier`` claim.
      A missing, unknown, or non-string tier gets the ``free`` limit.

    Requests without a token, or whose token fails to decode/verify, pass
    through unthrottled. The auth dependency is expected to reject them with
    401 before the route handler runs.

    Only accepted requests are recorded; rejected (429) requests do not extend
    the window. All state lives in this instance, so limits are enforced per
    process.
    """

    # Length of the sliding window, in seconds. The 429 message says
    # "req/min", so keep this at 60 unless that wording is updated too.
    _WINDOW_SECONDS: float = 60.0

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)
        # rate-limit key -> ascending time.monotonic() readings of the
        # accepted requests (monotonic-clock seconds, NOT seconds since epoch).
        self._window: dict[str, list[float]] = defaultdict(list)
        # Monotonic time of the last idle-key sweep; see _evict_idle().
        self._last_sweep: float = time.monotonic()

    @staticmethod
    def _extract_token(header: str) -> str:
        """Return the credentials part of an ``Authorization`` header value.

        A ``Bearer`` scheme is stripped case-insensitively (RFC 7235). A value
        without a scheme is returned stripped, so raw tokens still work. Returns
        ``""`` when there is nothing usable.
        """
        scheme, sep, credentials = header.partition(" ")
        if sep and scheme.lower() == "bearer":
            return credentials.strip()
        return header.strip()

    def _evict_idle(self, window_start: float) -> None:
        """Forget keys whose newest accepted request is at or before *window_start*.

        Every timestamp of such a key is outside the window already, so this
        loses no information. Without it, each key ever seen would stay in
        ``_window`` for the life of the process.
        """
        stale = [
            key
            for key, stamps in self._window.items()
            if not stamps or stamps[-1] <= window_start
        ]
        for key in stale:
            del self._window[key]

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[override]
        if request.url.path in _EXEMPT_PATHS:
            return await call_next(request)

        # Extract JWT claims — if no valid token, pass through for auth dep to handle.
        token = self._extract_token(request.headers.get("Authorization", ""))
        if not token:
            return await call_next(request)

        try:
            payload = jwt.decode(
                token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
            )
        except JWTError:
            return await call_next(request)

        user_id: str = payload.get("sub") or get_remote_address(request)
        tier = payload.get("tier", "free")
        # An unhashable claim (e.g. a list) would raise TypeError on the dict
        # lookup. Unknown tiers are reported as "free" so the 429 message
        # names the limit that is actually applied.
        limit = _TIER_LIMITS.get(tier) if isinstance(tier, str) else None
        if limit is None:
            tier, limit = "free", _TIER_LIMITS["free"]

        now = time.monotonic()
        window_start = now - self._WINDOW_SECONDS

        # At most once per window, drop idle keys so memory stays bounded by
        # the number of recently active callers.
        if now - self._last_sweep >= self._WINDOW_SECONDS:
            self._evict_idle(window_start)
            self._last_sweep = now

        # Slide the window: keep only timestamps newer than window_start.
        # .get() avoids inserting an empty entry for a key we may reject.
        timestamps = [t for t in self._window.get(user_id, ()) if t > window_start]

        if len(timestamps) >= limit:
            # The oldest entry leaves the window _WINDOW_SECONDS after it was
            # recorded. Round up so clients honoring Retry-After do not come
            # back before then and hit another 429.
            retry_after = max(
                1, math.ceil(self._WINDOW_SECONDS - (now - min(timestamps)))
            )
            return Response(
                content=json.dumps(
                    {
                        "detail": (
                            f"Rate limit exceeded ({limit} req/min for {tier} tier). "
                            f"Retry in {retry_after}s."
                        )
                    }
                ),
                status_code=429,
                headers={
                    "Retry-After": str(retry_after),
                    "Content-Type": "application/json",
                },
            )

        timestamps.append(now)
        self._window[user_id] = timestamps
        return await call_next(request)
|