step 9 complete: auth middleware, tier-aware rate limiter, and response sanitizer

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-02 22:18:17 +01:00
parent 4c4df7335a
commit 3e07fff958
8 changed files with 661 additions and 44 deletions

View File

@@ -0,0 +1,19 @@
"""API middleware package.
Exports the three middleware components introduced in Step 9:
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
- Sanitizer: ``SanitizerMiddleware``
"""
from app.api.middleware.auth import get_current_user, oauth2_scheme
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
from app.api.middleware.sanitizer import SanitizerMiddleware
__all__ = [
"get_current_user",
"oauth2_scheme",
"TierRateLimitMiddleware",
"limiter",
"SanitizerMiddleware",
]

View File

@@ -0,0 +1,51 @@
"""Auth middleware — JWT validation dependency.
``get_current_user`` is the FastAPI dependency used by all protected routes.
It decodes the Bearer JWT, validates signature and expiry, and returns a
``UserProfile`` carrying ``id``, ``email``, and ``tier``.
Exempt routes (no JWT required):
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
"""
from __future__ import annotations
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from app.config.settings import settings
from app.schemas import UserProfile
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
token: str = Depends(oauth2_scheme),
) -> UserProfile:
"""Validate a Bearer JWT and return the authenticated user.
Raises HTTP 401 on any invalid or expired token.
The tier embedded in the JWT is used for feature-gating until Step 12
adds a live DB lookup.
"""
credentials_exc = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
tier: str = payload.get("tier", "free")
if not user_id or not email:
raise credentials_exc
except JWTError:
raise credentials_exc
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]

View File

@@ -0,0 +1,129 @@
"""Tier-aware rate limiting middleware.
Uses a per-user sliding-window counter (in-process, no Redis required).
The ``slowapi`` Limiter is also exported for optional route-level decoration.
Limits (requests per minute):
- free: 20
- pro: 60
- power: 120
- team: 200
Exempt paths bypass the limiter entirely:
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
- GET /api/v1/health
"""
from __future__ import annotations
import json
import time
from collections import defaultdict
from fastapi import Request, Response
from jose import JWTError, jwt
from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from app.config.settings import settings
_TIER_LIMITS: dict[str, int] = {
"free": 20,
"pro": 60,
"power": 120,
"team": 200,
}
_EXEMPT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/register",
"/api/v1/auth/login",
"/api/v1/billing/webhook",
"/api/v1/health",
}
)
def _get_user_id_from_jwt(request: Request) -> str:
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
auth = request.headers.get("Authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not token:
return get_remote_address(request)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
return payload.get("sub") or get_remote_address(request)
except JWTError:
return get_remote_address(request)
# Exported Limiter instance — available for optional route-level decoration.
limiter = Limiter(key_func=_get_user_id_from_jwt)
class TierRateLimitMiddleware(BaseHTTPMiddleware):
"""Sliding-window rate limiter applied globally across all non-exempt routes.
Each authenticated user gets their own 60-second window sized by tier.
Unauthenticated requests pass through (the auth dependency will reject them
with 401 before the route handler runs).
"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
# user_id → list of request timestamps (float, seconds since epoch)
self._window: dict[str, list[float]] = defaultdict(list)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
if request.url.path in _EXEMPT_PATHS:
return await call_next(request)
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
auth = request.headers.get("Authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not token:
return await call_next(request)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str = payload.get("sub") or get_remote_address(request)
tier: str = payload.get("tier", "free")
except JWTError:
return await call_next(request)
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
now = time.monotonic()
window_start = now - 60.0
# Slide the window: discard timestamps older than 60 seconds.
timestamps = [t for t in self._window[user_id] if t > window_start]
if len(timestamps) >= limit:
retry_after = max(1, int(60 - (now - min(timestamps))))
return Response(
content=json.dumps(
{
"detail": (
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
f"Retry in {retry_after}s."
)
}
),
status_code=429,
headers={
"Retry-After": str(retry_after),
"Content-Type": "application/json",
},
)
timestamps.append(now)
self._window[user_id] = timestamps
return await call_next(request)

View File

@@ -0,0 +1,139 @@
"""Response sanitizer middleware.
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
that could reveal server-side prompt IP:
- System prompt openers ("You are a/an/the …")
- Agent routing metadata ("Available agents:", "intent classifier", …)
- LangChain tool schema fragments (``"type": "function"``)
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
- Exact-match known prompt fingerprints
Binary responses (storage blobs, backup data) are never touched — the
middleware only activates for paths under /api/v1/chat.
Any sanitisation event is logged as a WARNING with the request path and the
names of the fields that were modified.
"""
from __future__ import annotations
import json
import logging
import re
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Detection patterns — order matters: fingerprints checked first (exact),
# then compiled regexes.
# ---------------------------------------------------------------------------
_FINGERPRINTS: tuple[str, ...] = (
"You are an intent classifier",
"Respond with just the agent name",
"Summarize these agent results",
"Available agents:",
"route to:",
)
_PATTERNS: tuple[re.Pattern[str], ...] = (
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
re.compile(r"Available agents\s*:", re.IGNORECASE),
re.compile(r"\bintent classifier\b", re.IGNORECASE),
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
re.compile(r"route\s+to\s*:", re.IGNORECASE),
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
)
def _sanitize_text(text: str) -> tuple[str, bool]:
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
Returns ``(cleaned_text, was_changed)``.
"""
# Fingerprint check — if any exact phrase is present, redact the whole string.
for fp in _FINGERPRINTS:
if fp in text:
return "[REDACTED]", True
changed = False
for pattern in _PATTERNS:
new_text, n = pattern.subn("[REDACTED]", text)
if n:
text = new_text
changed = True
return text, changed
class SanitizerMiddleware(BaseHTTPMiddleware):
"""Strip prompt IP from /api/v1/chat JSON responses."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
response: Response = await call_next(request)
# Only process chat endpoint responses.
if not request.url.path.startswith("/api/v1/chat"):
return response
# Read body — collect streaming chunks.
body_bytes = b""
async for chunk in response.body_iterator:
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
try:
body = json.loads(body_bytes.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
return Response(
content=body_bytes,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
if not isinstance(body, dict):
return Response(
content=body_bytes,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
# Walk top-level string fields and sanitise.
sanitised_fields: list[str] = []
for key, value in body.items():
if isinstance(value, str):
cleaned, changed = _sanitize_text(value)
if changed:
body[key] = cleaned
sanitised_fields.append(key)
if sanitised_fields:
logger.warning(
"Sanitizer redacted prompt fragments",
extra={
"path": request.url.path,
"fields": sanitised_fields,
},
)
new_body = json.dumps(body).encode("utf-8")
headers = dict(response.headers)
headers["content-length"] = str(len(new_body))
return Response(
content=new_body,
status_code=response.status_code,
headers=headers,
media_type="application/json",
)