step 9 complete: auth middleware, tier-aware rate limiter, and response sanitizer
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -331,14 +331,14 @@ adiuva-api/
|
||||
### Step 9 — Middleware
|
||||
|
||||
#### 9a — Auth middleware
|
||||
- [ ] `app/api/middleware/auth.py`:
|
||||
- [x] `app/api/middleware/auth.py`:
|
||||
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
||||
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
||||
- Raises `401` on invalid/expired token
|
||||
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
||||
|
||||
#### 9b — Rate limiter
|
||||
- [ ] `app/api/middleware/rate_limit.py`:
|
||||
- [x] `app/api/middleware/rate_limit.py`:
|
||||
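- Global in-process sliding-window middleware (`TierRateLimitMiddleware`); a `slowapi` `Limiter(key_func=_get_user_id_from_jwt)` is also exported for optional route-level decoration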
|
||||
- Tier-based limits:
|
||||
- Free: 20 req/min
|
||||
@@ -348,7 +348,7 @@ adiuva-api/
|
||||
- Custom 429 response with `Retry-After` header
|
||||
|
||||
#### 9c — Sanitizer
|
||||
- [ ] `app/api/middleware/sanitizer.py`:
|
||||
- [x] `app/api/middleware/sanitizer.py`:
|
||||
- Response middleware that scans response bodies
|
||||
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
||||
- Pattern-based detection + exact match against known prompt fingerprints
|
||||
|
||||
@@ -1,46 +1,14 @@
|
||||
"""Shared FastAPI dependencies.
|
||||
|
||||
``get_current_user`` decodes the Bearer JWT and returns a ``UserProfile``.
|
||||
Step 9 will layer rate-limiting and sanitization middleware on top of this.
|
||||
Step 12 will add a DB look-up to fetch the live tier from PostgreSQL.
|
||||
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
|
||||
(the canonical location per Step 9). This module re-exports them so that all
|
||||
existing route imports (``from app.api.deps import get_current_user``) continue
|
||||
to work without modification.
|
||||
|
||||
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
|
||||
instead of reading it from the JWT payload.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.schemas import BillingTier, UserProfile
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
) -> UserProfile:
|
||||
"""Validate a Bearer JWT and return the authenticated user.
|
||||
|
||||
Raises ``HTTP 401`` on any invalid or expired token.
|
||||
The tier embedded in the JWT is used for feature-gating until Step 12
|
||||
adds a live DB lookup.
|
||||
"""
|
||||
credentials_exc = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
email: str | None = payload.get("email")
|
||||
tier: str = payload.get("tier", "free")
|
||||
if not user_id or not email:
|
||||
raise credentials_exc
|
||||
except JWTError:
|
||||
raise credentials_exc
|
||||
|
||||
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
||||
__all__ = ["get_current_user", "oauth2_scheme"]
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
"""API middleware package.
|
||||
|
||||
Exports the three middleware components introduced in Step 9:
|
||||
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
|
||||
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
|
||||
- Sanitizer: ``SanitizerMiddleware``
|
||||
"""
|
||||
|
||||
from app.api.middleware.auth import get_current_user, oauth2_scheme
|
||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
|
||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
|
||||
__all__ = [
|
||||
"get_current_user",
|
||||
"oauth2_scheme",
|
||||
"TierRateLimitMiddleware",
|
||||
"limiter",
|
||||
"SanitizerMiddleware",
|
||||
]
|
||||
|
||||
51
app/api/middleware/auth.py
Normal file
51
app/api/middleware/auth.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Auth middleware — JWT validation dependency.
|
||||
|
||||
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
||||
It decodes the Bearer JWT, validates signature and expiry, and returns a
|
||||
``UserProfile`` carrying ``id``, ``email``, and ``tier``.
|
||||
|
||||
Exempt routes (no JWT required):
|
||||
- POST /api/v1/auth/register
|
||||
- POST /api/v1/auth/login
|
||||
- POST /api/v1/billing/webhook
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.schemas import UserProfile
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
    token: str = Depends(oauth2_scheme),
) -> UserProfile:
    """Validate a Bearer JWT and return the authenticated user.

    The token is decoded with ``settings.JWT_SECRET`` and
    ``settings.JWT_ALGORITHM``. The ``sub`` and ``email`` claims are required;
    ``tier`` falls back to ``"free"`` when the claim is absent. The tier
    embedded in the JWT is used for feature-gating until Step 12 adds a live
    DB lookup.

    Args:
        token: Raw Bearer token supplied by ``oauth2_scheme``.

    Returns:
        A ``UserProfile`` built from the ``sub``, ``email`` and ``tier`` claims.

    Raises:
        HTTPException: ``401`` with a ``WWW-Authenticate: Bearer`` header when
            the token cannot be decoded, when ``sub`` or ``email`` is missing,
            or when the claims are rejected while building the profile.
    """
    credentials_exc = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Keep the try body down to the one call that can raise JWTError.
    try:
        payload = jwt.decode(
            token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
    except JWTError as exc:
        raise credentials_exc from exc

    user_id: str | None = payload.get("sub")
    email: str | None = payload.get("email")
    tier: str = payload.get("tier", "free")
    if not user_id or not email:
        raise credentials_exc

    try:
        return UserProfile(id=user_id, email=email, tier=tier)  # type: ignore[arg-type]
    except ValueError as exc:
        # NOTE(review): assumes UserProfile is a pydantic model, whose
        # ValidationError subclasses ValueError — confirm. A correctly signed
        # token with unusable claims (e.g. an unknown tier) is still an invalid
        # credential, so answer 401 instead of letting the error escape.
        raise credentials_exc from exc
|
||||
129
app/api/middleware/rate_limit.py
Normal file
129
app/api/middleware/rate_limit.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Tier-aware rate limiting middleware.
|
||||
|
||||
Uses a per-user sliding-window counter (in-process, no Redis required).
|
||||
The ``slowapi`` Limiter is also exported for optional route-level decoration.
|
||||
|
||||
Limits (requests per minute):
|
||||
- free: 20
|
||||
- pro: 60
|
||||
- power: 120
|
||||
- team: 200
|
||||
|
||||
Exempt paths bypass the limiter entirely:
|
||||
- POST /api/v1/auth/register
|
||||
- POST /api/v1/auth/login
|
||||
- POST /api/v1/billing/webhook
|
||||
- GET /api/v1/health
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import Request, Response
|
||||
from jose import JWTError, jwt
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
_TIER_LIMITS: dict[str, int] = {
|
||||
"free": 20,
|
||||
"pro": 60,
|
||||
"power": 120,
|
||||
"team": 200,
|
||||
}
|
||||
|
||||
_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/billing/webhook",
|
||||
"/api/v1/health",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_user_id_from_jwt(request: Request) -> str:
    """Return the rate-limit key for *request*: the JWT ``sub``, else the remote address.

    The ``Authorization`` header is read with any ``"Bearer "`` prefix removed.
    When the header is empty, the token fails to decode, or the decoded
    payload has no truthy ``sub`` claim, the value of
    ``get_remote_address(request)`` is returned instead.
    """
    bearer = request.headers.get("Authorization", "").removeprefix("Bearer ").strip()
    if not bearer:
        return get_remote_address(request)

    try:
        claims = jwt.decode(
            bearer, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
    except JWTError:
        # Undecodable token: treat it the same as having no subject.
        claims = {}

    return claims.get("sub") or get_remote_address(request)
|
||||
|
||||
|
||||
# Exported Limiter instance — available for optional route-level decoration.
|
||||
limiter = Limiter(key_func=_get_user_id_from_jwt)
|
||||
|
||||
|
||||
class TierRateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window rate limiter applied globally across all non-exempt routes.

    Each authenticated user gets their own 60-second window sized by tier
    (``_TIER_LIMITS``; unknown tiers use the ``free`` limit). Requests without
    a decodable Bearer JWT pass through untouched so that the auth dependency
    can reject them with 401 before the route handler runs.

    State is held in-process in ``self._window``. Entries for users whose
    newest request has left the window are evicted periodically so the mapping
    does not grow with every user ever seen.
    """

    # Length of the sliding window, in seconds.
    _WINDOW_SECONDS: float = 60.0

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)
        # user_id → request timestamps (``time.monotonic()`` seconds), oldest first.
        self._window: dict[str, list[float]] = defaultdict(list)
        # Monotonic time of the last idle-user sweep.
        self._last_sweep: float = time.monotonic()

    def _evict_idle_users(self, now: float) -> None:
        """Drop users with no request inside the window; runs at most once per window."""
        if now - self._last_sweep < self._WINDOW_SECONDS:
            return
        self._last_sweep = now
        cutoff = now - self._WINDOW_SECONDS
        # Timestamps are appended in increasing order, so the last is the newest.
        idle = [
            uid
            for uid, stamps in self._window.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for uid in idle:
            del self._window[uid]

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[override]
        """Forward the request, or answer 429 when the caller's window is full.

        Returns the downstream response for exempt paths, for requests without
        a decodable token, and for requests within the tier limit. Otherwise
        returns a JSON 429 response carrying a ``Retry-After`` header.
        """
        if request.url.path in _EXEMPT_PATHS:
            return await call_next(request)

        # Extract JWT claims — if no valid token, pass through for auth dep to handle.
        auth = request.headers.get("Authorization", "")
        token = auth.removeprefix("Bearer ").strip()
        if not token:
            return await call_next(request)

        try:
            payload = jwt.decode(
                token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
            )
        except JWTError:
            return await call_next(request)

        user_id: str = payload.get("sub") or get_remote_address(request)
        tier: str = payload.get("tier", "free")

        limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
        now = time.monotonic()
        window_start = now - self._WINDOW_SECONDS
        self._evict_idle_users(now)

        # Slide the window: discard timestamps older than the window.
        # ``.get`` avoids creating an empty entry for users who end up rejected.
        timestamps = [t for t in self._window.get(user_id, ()) if t > window_start]

        if len(timestamps) >= limit:
            # Seconds until the oldest counted request leaves the window.
            retry_after = max(1, int(self._WINDOW_SECONDS - (now - min(timestamps))))
            return Response(
                content=json.dumps(
                    {
                        "detail": (
                            f"Rate limit exceeded ({limit} req/min for {tier} tier). "
                            f"Retry in {retry_after}s."
                        )
                    }
                ),
                status_code=429,
                headers={
                    "Retry-After": str(retry_after),
                    "Content-Type": "application/json",
                },
            )

        timestamps.append(now)
        self._window[user_id] = timestamps
        return await call_next(request)
|
||||
139
app/api/middleware/sanitizer.py
Normal file
139
app/api/middleware/sanitizer.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Response sanitizer middleware.
|
||||
|
||||
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
|
||||
that could reveal server-side prompt IP:
|
||||
- System prompt openers ("You are a/an/the …")
|
||||
- Agent routing metadata ("Available agents:", "intent classifier", …)
|
||||
- LangChain tool schema fragments (``"type": "function"``)
|
||||
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
||||
- Exact-match known prompt fingerprints
|
||||
|
||||
Binary responses (storage blobs, backup data) are never touched — the
|
||||
middleware only activates for paths under /api/v1/chat.
|
||||
|
||||
Any sanitisation event is logged as a WARNING with the request path and the
|
||||
names of the fields that were modified.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection patterns — order matters: fingerprints checked first (exact),
|
||||
# then compiled regexes.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FINGERPRINTS: tuple[str, ...] = (
|
||||
"You are an intent classifier",
|
||||
"Respond with just the agent name",
|
||||
"Summarize these agent results",
|
||||
"Available agents:",
|
||||
"route to:",
|
||||
)
|
||||
|
||||
_PATTERNS: tuple[re.Pattern[str], ...] = (
|
||||
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"Available agents\s*:", re.IGNORECASE),
|
||||
re.compile(r"\bintent classifier\b", re.IGNORECASE),
|
||||
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
|
||||
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
|
||||
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
|
||||
re.compile(r"route\s+to\s*:", re.IGNORECASE),
|
||||
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_text(text: str) -> tuple[str, bool]:
|
||||
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
|
||||
|
||||
Returns ``(cleaned_text, was_changed)``.
|
||||
"""
|
||||
# Fingerprint check — if any exact phrase is present, redact the whole string.
|
||||
for fp in _FINGERPRINTS:
|
||||
if fp in text:
|
||||
return "[REDACTED]", True
|
||||
|
||||
changed = False
|
||||
for pattern in _PATTERNS:
|
||||
new_text, n = pattern.subn("[REDACTED]", text)
|
||||
if n:
|
||||
text = new_text
|
||||
changed = True
|
||||
|
||||
return text, changed
|
||||
|
||||
|
||||
class SanitizerMiddleware(BaseHTTPMiddleware):
    """Strip prompt IP from /api/v1/chat JSON responses.

    Responses for any other path are returned untouched. For chat paths the
    body is buffered, parsed as JSON, and every top-level string field is
    passed through ``_sanitize_text``. The body is only re-serialised when at
    least one field was actually redacted; otherwise the original bytes are
    forwarded unchanged.
    """

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    @staticmethod
    def _passthrough(response: Response, body_bytes: bytes) -> Response:
        """Rebuild *response* around the already-consumed, unmodified body bytes."""
        return Response(
            content=body_bytes,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type,
        )

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[override]
        """Run the downstream handler and sanitise its response when on a chat path.

        Returns the downstream response as-is for non-chat paths, a byte-identical
        copy for chat bodies that are not JSON objects or need no redaction, and
        a re-serialised JSON response (with corrected ``content-length``) when
        fields were redacted. Each redaction is logged as a WARNING with the
        request path and the affected field names.
        """
        response: Response = await call_next(request)

        # Only process chat endpoint responses.
        if not request.url.path.startswith("/api/v1/chat"):
            return response

        # Read body — collect streaming chunks, then join once.
        chunks: list[bytes] = []
        async for chunk in response.body_iterator:
            chunks.append(chunk if isinstance(chunk, bytes) else chunk.encode())
        body_bytes = b"".join(chunks)

        # Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
        try:
            body = json.loads(body_bytes.decode("utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError):
            return self._passthrough(response, body_bytes)

        if not isinstance(body, dict):
            return self._passthrough(response, body_bytes)

        # Walk top-level string fields and sanitise.
        sanitised_fields: list[str] = []
        for key, value in body.items():
            if isinstance(value, str):
                cleaned, changed = _sanitize_text(value)
                if changed:
                    body[key] = cleaned
                    sanitised_fields.append(key)

        if not sanitised_fields:
            # Nothing was redacted: forward the original bytes rather than a
            # json.dumps round-trip, which would alter escaping and separators.
            return self._passthrough(response, body_bytes)

        logger.warning(
            "Sanitizer redacted prompt fragments",
            extra={
                "path": request.url.path,
                "fields": sanitised_fields,
            },
        )

        new_body = json.dumps(body).encode("utf-8")
        headers = dict(response.headers)
        # The body length changed, so the upstream content-length is stale.
        headers["content-length"] = str(len(new_body))

        return Response(
            content=new_body,
            status_code=response.status_code,
            headers=headers,
            media_type="application/json",
        )
|
||||
@@ -3,6 +3,8 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
@@ -33,6 +35,11 @@ def create_app() -> FastAPI:
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
|
||||
# Request flow: TierRateLimit → Sanitizer → CORS → Router
|
||||
# Response flow: Router → CORS → Sanitizer → TierRateLimit
|
||||
app.add_middleware(SanitizerMiddleware)
|
||||
app.add_middleware(TierRateLimitMiddleware)
|
||||
|
||||
from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors
|
||||
|
||||
|
||||
304
tests/test_middleware.py
Normal file
304
tests/test_middleware.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""Tests for Step 9 middleware: auth, rate limiting, and sanitizer.
|
||||
|
||||
Auth tests: validated via GET /api/v1/auth/me (requires a Bearer JWT).
|
||||
Rate limit: use unique user UUIDs per test so windows are independent;
|
||||
the free-tier threshold (20 req/min) is exercised directly.
|
||||
Sanitizer: the orchestrator is mocked to inject controlled prompt
|
||||
fragments, and the chat endpoint response body is inspected.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from jose import jwt
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.main import app
|
||||
from app.schemas import ChatResponse
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CHAT_BODY = {
|
||||
"message": "hello",
|
||||
"context": {
|
||||
"user_profile": {},
|
||||
"relevant_documents": [],
|
||||
"recent_tasks": [],
|
||||
"conversation_history": [],
|
||||
},
|
||||
"execution_mode": "direct",
|
||||
}
|
||||
|
||||
|
||||
def _make_jwt(
    *,
    user_id: str | None = None,
    email: str = "test@example.com",
    tier: str = "free",
    exp_offset: int = 3600,
    secret: str | None = None,
    include_sub: bool = True,
) -> str:
    """Mint a test JWT signed with the configured (or custom) secret.

    The token carries ``email``, ``tier``, ``exp`` (now + *exp_offset* seconds)
    and ``iat`` claims. ``sub`` is added only when *include_sub* is true and is
    *user_id*, or a fresh UUID4 string when *user_id* is falsy. Signing uses
    *secret* when given, otherwise ``settings.JWT_SECRET``, with
    ``settings.JWT_ALGORITHM``.
    """
    subject = user_id if user_id else str(uuid.uuid4())
    issued_at = int(time.time())
    claims: dict = {
        "email": email,
        "tier": tier,
        "exp": issued_at + exp_offset,
        "iat": issued_at,
    }
    if include_sub:
        claims["sub"] = subject

    signing_key = secret or settings.JWT_SECRET
    return jwt.encode(claims, signing_key, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def _auth_header(token: str) -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuthMiddleware:
    """Auth dependency checks, all driven through GET /api/v1/auth/me."""

    def test_valid_token_returns_profile(self) -> None:
        expected = {
            "id": str(uuid.uuid4()),
            "email": "alice@example.com",
            "tier": "pro",
        }
        bearer = _make_jwt(
            user_id=expected["id"], email=expected["email"], tier=expected["tier"]
        )
        with TestClient(app) as client:
            resp = client.get("/api/v1/auth/me", headers=_auth_header(bearer))
        assert resp.status_code == 200
        profile = resp.json()
        assert profile["id"] == expected["id"]
        assert profile["email"] == expected["email"]
        assert profile["tier"] == expected["tier"]

    def test_missing_token_returns_401(self) -> None:
        # No Authorization header at all.
        with TestClient(app) as client:
            resp = client.get("/api/v1/auth/me")
        assert resp.status_code == 401

    def test_expired_token_returns_401(self) -> None:
        # A negative offset puts ``exp`` in the past.
        stale = _make_jwt(exp_offset=-1)
        with TestClient(app) as client:
            resp = client.get("/api/v1/auth/me", headers=_auth_header(stale))
        assert resp.status_code == 401

    def test_wrong_signature_returns_401(self) -> None:
        forged = _make_jwt(secret="totally-wrong-secret")
        with TestClient(app) as client:
            resp = client.get("/api/v1/auth/me", headers=_auth_header(forged))
        assert resp.status_code == 401

    def test_missing_sub_claim_returns_401(self) -> None:
        anonymous = _make_jwt(include_sub=False)
        with TestClient(app) as client:
            resp = client.get("/api/v1/auth/me", headers=_auth_header(anonymous))
        assert resp.status_code == 401

    def test_malformed_token_returns_401(self) -> None:
        garbage_header = {"Authorization": "Bearer not.a.jwt"}
        with TestClient(app) as client:
            resp = client.get("/api/v1/auth/me", headers=garbage_header)
        assert resp.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rate limiter middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
    """Rate-limit checks; every test mints a brand-new user_id so windows never collide."""

    def _unique_token(self, tier: str = "free") -> str:
        fresh_user = str(uuid.uuid4())
        return _make_jwt(user_id=fresh_user, tier=tier)

    def test_free_tier_allows_up_to_20_requests(self) -> None:
        headers = _auth_header(self._unique_token("free"))
        with TestClient(app) as client:
            attempt = 0
            while attempt < 20:
                resp = client.get("/api/v1/auth/me", headers=headers)
                assert resp.status_code == 200
                attempt += 1

    def test_free_tier_blocks_21st_request(self) -> None:
        headers = _auth_header(self._unique_token("free"))
        with TestClient(app) as client:
            # Use up the whole free-tier quota first.
            for _ in range(20):
                client.get("/api/v1/auth/me", headers=headers)
            over_limit = client.get("/api/v1/auth/me", headers=headers)
        assert over_limit.status_code == 429

    def test_429_includes_retry_after_header(self) -> None:
        headers = _auth_header(self._unique_token("free"))
        with TestClient(app) as client:
            for _ in range(20):
                client.get("/api/v1/auth/me", headers=headers)
            over_limit = client.get("/api/v1/auth/me", headers=headers)
        assert over_limit.status_code == 429
        assert "retry-after" in over_limit.headers
        wait_seconds = int(over_limit.headers["retry-after"])
        assert wait_seconds >= 1

    def test_429_response_has_detail_field(self) -> None:
        headers = _auth_header(self._unique_token("free"))
        with TestClient(app) as client:
            for _ in range(20):
                client.get("/api/v1/auth/me", headers=headers)
            over_limit = client.get("/api/v1/auth/me", headers=headers)
        assert over_limit.status_code == 429
        assert "detail" in over_limit.json()

    def test_pro_tier_allows_60_requests(self) -> None:
        headers = _auth_header(self._unique_token("pro"))
        with TestClient(app) as client:
            # Sample: first 60 succeed, 61st is blocked.
            for _ in range(60):
                within_quota = client.get("/api/v1/auth/me", headers=headers)
                assert within_quota.status_code == 200
            over_limit = client.get("/api/v1/auth/me", headers=headers)
        assert over_limit.status_code == 429

    def test_independent_users_have_separate_windows(self) -> None:
        headers_a = _auth_header(self._unique_token("free"))
        headers_b = _auth_header(self._unique_token("free"))
        with TestClient(app) as client:
            # Exhaust user A's quota.
            for _ in range(20):
                client.get("/api/v1/auth/me", headers=headers_a)
            blocked_a = client.get("/api/v1/auth/me", headers=headers_a)
            assert blocked_a.status_code == 429
            # User B's quota is untouched.
            resp_b = client.get("/api/v1/auth/me", headers=headers_b)
        assert resp_b.status_code == 200

    def test_exempt_path_register_never_rate_limited(self) -> None:
        """POST /auth/register is exempt — 25 calls should never return 429."""
        with TestClient(app) as client:
            for i in range(25):
                credentials = {
                    "email": f"user{i}_{uuid.uuid4()}@example.com",
                    "password": "pw",
                }
                resp = client.post("/api/v1/auth/register", json=credentials)
                # 201 on first, 409 on email collision — but never 429.
                assert resp.status_code != 429

    def test_exempt_path_login_never_rate_limited(self) -> None:
        """POST /auth/login is exempt — multiple failed attempts are not rate-limited."""
        bad_credentials = {"email": "nosuchuser@example.com", "password": "wrong"}
        with TestClient(app) as client:
            for _ in range(25):
                resp = client.post("/api/v1/auth/login", json=bad_credentials)
                assert resp.status_code != 429

    def test_exempt_path_health_never_rate_limited(self) -> None:
        with TestClient(app) as client:
            for _ in range(25):
                resp = client.get("/api/v1/health")
                assert resp.status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sanitizer middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizerMiddleware:
    """Sanitizer checks: ``orchestrate`` is patched so the chat reply text is controlled."""

    _CHAT_PATH = "/api/v1/chat"

    def _token(self) -> str:
        fresh_user = str(uuid.uuid4())
        return _make_jwt(user_id=fresh_user, tier="pro")

    def _post_chat(self, client: TestClient, response_text: str) -> dict:
        """POST the canned chat body with ``orchestrate`` returning *response_text*."""
        canned = ChatResponse(response=response_text, actions=[])
        orchestrate_patch = patch(
            "app.api.routes.chat.orchestrate",
            new_callable=AsyncMock,
            return_value=canned,
        )
        with orchestrate_patch:
            resp = client.post(
                self._CHAT_PATH,
                json=_CHAT_BODY,
                headers=_auth_header(self._token()),
            )
        assert resp.status_code == 200
        return resp.json()

    def test_clean_response_passes_through_unchanged(self) -> None:
        with TestClient(app) as client:
            data = self._post_chat(client, "Sure, I created the task for you.")
        assert data["response"] == "Sure, I created the task for you."

    def test_strips_system_prompt_opener(self) -> None:
        leaked = "You are an intent classifier. Route to task_agent."
        with TestClient(app) as client:
            data = self._post_chat(client, leaked)
        assert "You are" not in data["response"]
        assert "[REDACTED]" in data["response"]

    def test_strips_known_fingerprint(self) -> None:
        leaked = "Respond with just the agent name and nothing else."
        with TestClient(app) as client:
            data = self._post_chat(client, leaked)
        assert data["response"] == "[REDACTED]"

    def test_strips_tool_schema_fragment(self) -> None:
        leaked = 'Here is the schema: {"type": "function", "name": "foo"}'
        with TestClient(app) as client:
            data = self._post_chat(client, leaked)
        assert '"type": "function"' not in data["response"]

    def test_strips_reasoning_tag(self) -> None:
        leaked = "<thinking>I should route this to calendar_agent</thinking>Done."
        with TestClient(app) as client:
            data = self._post_chat(client, leaked)
        assert "<thinking>" not in data["response"]
        assert "[REDACTED]" in data["response"]

    def test_strips_available_agents_fragment(self) -> None:
        leaked = "Available agents: task_agent, calendar_agent"
        with TestClient(app) as client:
            data = self._post_chat(client, leaked)
        assert "[REDACTED]" in data["response"]

    def test_sanitizer_does_not_activate_for_non_chat_path(self) -> None:
        """GET /api/v1/plans/playbook should pass through the sanitizer untouched."""
        headers = _auth_header(self._token())
        with TestClient(app) as client:
            resp = client.get("/api/v1/plans/playbook", headers=headers)
        # The sanitizer should not interfere — just check it returns something
        # (200 or whatever the route returns; we only care it's not broken).
        assert resp.status_code in (200, 401, 403, 404)

    def test_sanitizer_preserves_empty_response(self) -> None:
        with TestClient(app) as client:
            data = self._post_chat(client, "")
        assert data["response"] == ""
|
||||
Reference in New Issue
Block a user