From 3e07fff958e6608e2796dacdd8c78768cfcc3716 Mon Sep 17 00:00:00 2001 From: roberto Date: Mon, 2 Mar 2026 22:18:17 +0100 Subject: [PATCH] step 9 complete: auth middleware, tier-aware rate limiter, and response sanitizer Co-Authored-By: Claude Sonnet 4.6 --- BACKEND_PLAN.md | 6 +- app/api/deps.py | 50 +---- app/api/middleware/__init__.py | 19 ++ app/api/middleware/auth.py | 51 ++++++ app/api/middleware/rate_limit.py | 129 +++++++++++++ app/api/middleware/sanitizer.py | 139 ++++++++++++++ app/main.py | 7 + tests/test_middleware.py | 304 +++++++++++++++++++++++++++++++ 8 files changed, 661 insertions(+), 44 deletions(-) create mode 100644 app/api/middleware/auth.py create mode 100644 app/api/middleware/rate_limit.py create mode 100644 app/api/middleware/sanitizer.py create mode 100644 tests/test_middleware.py diff --git a/BACKEND_PLAN.md b/BACKEND_PLAN.md index da95873..1ae707c 100644 --- a/BACKEND_PLAN.md +++ b/BACKEND_PLAN.md @@ -331,14 +331,14 @@ adiuva-api/ ### Step 9 — Middleware #### 9a — Auth middleware -- [ ] `app/api/middleware/auth.py`: +- [x] `app/api/middleware/auth.py`: - FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile` - Validates JWT signature, expiry, extracts `user_id` and `tier` - Raises `401` on invalid/expired token - Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook` #### 9b — Rate limiter -- [ ] `app/api/middleware/rate_limit.py`: +- [x] `app/api/middleware/rate_limit.py`: - Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)` - Tier-based limits: - Free: 20 req/min @@ -348,7 +348,7 @@ adiuva-api/ - Custom 429 response with `Retry-After` header #### 9c — Sanitizer -- [ ] `app/api/middleware/sanitizer.py`: +- [x] `app/api/middleware/sanitizer.py`: - Response middleware that scans response bodies - Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata - Pattern-based detection + exact match against known prompt fingerprints diff --git a/app/api/deps.py b/app/api/deps.py index a8fb393..0339d0d 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -1,46 +1,14 @@ """Shared FastAPI dependencies. -``get_current_user`` decodes the Bearer JWT and returns a ``UserProfile``. -Step 9 will layer rate-limiting and sanitization middleware on top of this. -Step 12 will add a DB look-up to fetch the live tier from PostgreSQL. +``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth`` +(the canonical location per Step 9). This module re-exports them so that all +existing route imports (``from app.api.deps import get_current_user``) continue +to work without modification. + +Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL +instead of reading it from the JWT payload. """ -from __future__ import annotations +from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401 -from fastapi import Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer -from jose import JWTError, jwt - -from app.config.settings import settings -from app.schemas import BillingTier, UserProfile - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") - - -async def get_current_user( - token: str = Depends(oauth2_scheme), -) -> UserProfile: - """Validate a Bearer JWT and return the authenticated user. - - Raises ``HTTP 401`` on any invalid or expired token. - The tier embedded in the JWT is used for feature-gating until Step 12 - adds a live DB lookup. - """ - credentials_exc = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode( - token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] - ) - user_id: str | None = payload.get("sub") - email: str | None = payload.get("email") - tier: str = payload.get("tier", "free") - if not user_id or not email: - raise credentials_exc - except JWTError: - raise credentials_exc - - return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type] +__all__ = ["get_current_user", "oauth2_scheme"] diff --git a/app/api/middleware/__init__.py b/app/api/middleware/__init__.py index e69de29..f67fc41 100644 --- a/app/api/middleware/__init__.py +++ b/app/api/middleware/__init__.py @@ -0,0 +1,19 @@ +"""API middleware package. + +Exports the three middleware components introduced in Step 9: + - Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme`` + - Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter) + - Sanitizer: ``SanitizerMiddleware`` +""" + +from app.api.middleware.auth import get_current_user, oauth2_scheme +from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter +from app.api.middleware.sanitizer import SanitizerMiddleware + +__all__ = [ + "get_current_user", + "oauth2_scheme", + "TierRateLimitMiddleware", + "limiter", + "SanitizerMiddleware", +] diff --git a/app/api/middleware/auth.py b/app/api/middleware/auth.py new file mode 100644 index 0000000..b596121 --- /dev/null +++ b/app/api/middleware/auth.py @@ -0,0 +1,51 @@ +"""Auth middleware — JWT validation dependency. + +``get_current_user`` is the FastAPI dependency used by all protected routes. +It decodes the Bearer JWT, validates signature and expiry, and returns a +``UserProfile`` carrying ``id``, ``email``, and ``tier``. + +Exempt routes (no JWT required): + - POST /api/v1/auth/register + - POST /api/v1/auth/login + - POST /api/v1/billing/webhook +""" + +from __future__ import annotations + +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt + +from app.config.settings import settings +from app.schemas import UserProfile + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") + + +async def get_current_user( + token: str = Depends(oauth2_scheme), +) -> UserProfile: + """Validate a Bearer JWT and return the authenticated user. + + Raises HTTP 401 on any invalid or expired token. + The tier embedded in the JWT is used for feature-gating until Step 12 + adds a live DB lookup. + """ + credentials_exc = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode( + token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] + ) + user_id: str | None = payload.get("sub") + email: str | None = payload.get("email") + tier: str = payload.get("tier", "free") + if not user_id or not email: + raise credentials_exc + except JWTError: + raise credentials_exc + + return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type] diff --git a/app/api/middleware/rate_limit.py b/app/api/middleware/rate_limit.py new file mode 100644 index 0000000..4a2af76 --- /dev/null +++ b/app/api/middleware/rate_limit.py @@ -0,0 +1,129 @@ +"""Tier-aware rate limiting middleware. + +Uses a per-user sliding-window counter (in-process, no Redis required). +The ``slowapi`` Limiter is also exported for optional route-level decoration. + +Limits (requests per minute): + - free: 20 + - pro: 60 + - power: 120 + - team: 200 + +Exempt paths bypass the limiter entirely: + - POST /api/v1/auth/register + - POST /api/v1/auth/login + - POST /api/v1/billing/webhook + - GET /api/v1/health +""" + +from __future__ import annotations + +import json +import time +from collections import defaultdict + +from fastapi import Request, Response +from jose import JWTError, jwt +from slowapi import Limiter +from slowapi.util import get_remote_address +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp + +from app.config.settings import settings + +_TIER_LIMITS: dict[str, int] = { + "free": 20, + "pro": 60, + "power": 120, + "team": 200, +} + +_EXEMPT_PATHS: frozenset[str] = frozenset( + { + "/api/v1/auth/register", + "/api/v1/auth/login", + "/api/v1/billing/webhook", + "/api/v1/health", + } +) + + +def _get_user_id_from_jwt(request: Request) -> str: + """Key function for the slowapi Limiter: returns JWT sub or remote IP.""" + auth = request.headers.get("Authorization", "") + token = auth.removeprefix("Bearer ").strip() + if not token: + return get_remote_address(request) + try: + payload = jwt.decode( + token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] + ) + return payload.get("sub") or get_remote_address(request) + except JWTError: + return get_remote_address(request) + + +# Exported Limiter instance — available for optional route-level decoration. +limiter = Limiter(key_func=_get_user_id_from_jwt) + + +class TierRateLimitMiddleware(BaseHTTPMiddleware): + """Sliding-window rate limiter applied globally across all non-exempt routes. + + Each authenticated user gets their own 60-second window sized by tier. + Unauthenticated requests pass through (the auth dependency will reject them + with 401 before the route handler runs). + """ + + def __init__(self, app: ASGIApp) -> None: + super().__init__(app) + # user_id → list of request timestamps (float, seconds since epoch) + self._window: dict[str, list[float]] = defaultdict(list) + + async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override] + if request.url.path in _EXEMPT_PATHS: + return await call_next(request) + + # Extract JWT claims — if no valid token, pass through for auth dep to handle. + auth = request.headers.get("Authorization", "") + token = auth.removeprefix("Bearer ").strip() + if not token: + return await call_next(request) + + try: + payload = jwt.decode( + token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] + ) + user_id: str = payload.get("sub") or get_remote_address(request) + tier: str = payload.get("tier", "free") + except JWTError: + return await call_next(request) + + limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"]) + now = time.monotonic() + window_start = now - 60.0 + + # Slide the window: discard timestamps older than 60 seconds. + timestamps = [t for t in self._window[user_id] if t > window_start] + + if len(timestamps) >= limit: + retry_after = max(1, int(60 - (now - min(timestamps)))) + return Response( + content=json.dumps( + { + "detail": ( + f"Rate limit exceeded ({limit} req/min for {tier} tier). " + f"Retry in {retry_after}s." + ) + } + ), + status_code=429, + headers={ + "Retry-After": str(retry_after), + "Content-Type": "application/json", + }, + ) + + timestamps.append(now) + self._window[user_id] = timestamps + return await call_next(request) diff --git a/app/api/middleware/sanitizer.py b/app/api/middleware/sanitizer.py new file mode 100644 index 0000000..570937f --- /dev/null +++ b/app/api/middleware/sanitizer.py @@ -0,0 +1,139 @@ +"""Response sanitizer middleware. + +Scans JSON responses from the /api/v1/chat endpoint and strips any fragments +that could reveal server-side prompt IP: + - System prompt openers ("You are a/an/the …") + - Agent routing metadata ("Available agents:", "intent classifier", …) + - LangChain tool schema fragments (``"type": "function"``) + - Internal reasoning markers (, , [INST], …) + - Exact-match known prompt fingerprints + +Binary responses (storage blobs, backup data) are never touched — the +middleware only activates for paths under /api/v1/chat. + +Any sanitisation event is logged as a WARNING with the request path and the +names of the fields that were modified. +""" + +from __future__ import annotations + +import json +import logging +import re + +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Detection patterns — order matters: fingerprints checked first (exact), +# then compiled regexes. +# --------------------------------------------------------------------------- + +_FINGERPRINTS: tuple[str, ...] = ( + "You are an intent classifier", + "Respond with just the agent name", + "Summarize these agent results", + "Available agents:", + "route to:", +) + +_PATTERNS: tuple[re.Pattern[str], ...] = ( + re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL), + re.compile(r"Available agents\s*:", re.IGNORECASE), + re.compile(r"\bintent classifier\b", re.IGNORECASE), + re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema + re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE), + re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers + re.compile(r"route\s+to\s*:", re.IGNORECASE), + re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE), +) + + +def _sanitize_text(text: str) -> tuple[str, bool]: + """Scan *text* for prompt fragments and replace matches with ``[REDACTED]``. + + Returns ``(cleaned_text, was_changed)``. + """ + # Fingerprint check — if any exact phrase is present, redact the whole string. + for fp in _FINGERPRINTS: + if fp in text: + return "[REDACTED]", True + + changed = False + for pattern in _PATTERNS: + new_text, n = pattern.subn("[REDACTED]", text) + if n: + text = new_text + changed = True + + return text, changed + + +class SanitizerMiddleware(BaseHTTPMiddleware): + """Strip prompt IP from /api/v1/chat JSON responses.""" + + def __init__(self, app: ASGIApp) -> None: + super().__init__(app) + + async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override] + response: Response = await call_next(request) + + # Only process chat endpoint responses. + if not request.url.path.startswith("/api/v1/chat"): + return response + + # Read body — collect streaming chunks. + body_bytes = b"" + async for chunk in response.body_iterator: + body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode() + + # Skip non-JSON bodies (shouldn't happen on /chat, but be safe). + try: + body = json.loads(body_bytes.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError): + return Response( + content=body_bytes, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.media_type, + ) + + if not isinstance(body, dict): + return Response( + content=body_bytes, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.media_type, + ) + + # Walk top-level string fields and sanitise. + sanitised_fields: list[str] = [] + for key, value in body.items(): + if isinstance(value, str): + cleaned, changed = _sanitize_text(value) + if changed: + body[key] = cleaned + sanitised_fields.append(key) + + if sanitised_fields: + logger.warning( + "Sanitizer redacted prompt fragments", + extra={ + "path": request.url.path, + "fields": sanitised_fields, + }, + ) + + new_body = json.dumps(body).encode("utf-8") + headers = dict(response.headers) + headers["content-length"] = str(len(new_body)) + + return Response( + content=new_body, + status_code=response.status_code, + headers=headers, + media_type="application/json", + ) diff --git a/app/main.py b/app/main.py index 30f42b8..8db1a20 100644 --- a/app/main.py +++ b/app/main.py @@ -3,6 +3,8 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from app.api.middleware.rate_limit import TierRateLimitMiddleware +from app.api.middleware.sanitizer import SanitizerMiddleware from app.config.settings import settings @@ -33,6 +35,11 @@ def create_app() -> FastAPI: allow_methods=["*"], allow_headers=["*"], ) + # Middleware stack (Starlette inserts at position 0, so last-added = outermost). + # Request flow: TierRateLimit → Sanitizer → CORS → Router + # Response flow: Router → CORS → Sanitizer → TierRateLimit + app.add_middleware(SanitizerMiddleware) + app.add_middleware(TierRateLimitMiddleware) from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 0000000..343a171 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,304 @@ +"""Tests for Step 9 middleware: auth, rate limiting, and sanitizer. + +Auth tests: validated via GET /api/v1/auth/me (requires a Bearer JWT). +Rate limit: use unique user UUIDs per test so windows are independent; + the free-tier threshold (20 req/min) is exercised directly. +Sanitizer: the orchestrator is mocked to inject controlled prompt + fragments, and the chat endpoint response body is inspected. +""" + +from __future__ import annotations + +import time +import uuid +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi.testclient import TestClient +from jose import jwt + +from app.config.settings import settings +from app.main import app +from app.schemas import ChatResponse + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_CHAT_BODY = { + "message": "hello", + "context": { + "user_profile": {}, + "relevant_documents": [], + "recent_tasks": [], + "conversation_history": [], + }, + "execution_mode": "direct", +} + + +def _make_jwt( + *, + user_id: str | None = None, + email: str = "test@example.com", + tier: str = "free", + exp_offset: int = 3600, + secret: str | None = None, + include_sub: bool = True, +) -> str: + """Mint a test JWT signed with the configured (or custom) secret.""" + uid = user_id or str(uuid.uuid4()) + now = int(time.time()) + payload: dict = { + "email": email, + "tier": tier, + "exp": now + exp_offset, + "iat": now, + } + if include_sub: + payload["sub"] = uid + key = secret or settings.JWT_SECRET + return jwt.encode(payload, key, algorithm=settings.JWT_ALGORITHM) + + +def _auth_header(token: str) -> dict[str, str]: + return {"Authorization": f"Bearer {token}"} + + +# --------------------------------------------------------------------------- +# Auth middleware +# --------------------------------------------------------------------------- + + +class TestAuthMiddleware: + """Tests exercised via GET /api/v1/auth/me.""" + + def test_valid_token_returns_profile(self) -> None: + uid = str(uuid.uuid4()) + token = _make_jwt(user_id=uid, email="alice@example.com", tier="pro") + with TestClient(app) as client: + resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == uid + assert data["email"] == "alice@example.com" + assert data["tier"] == "pro" + + def test_missing_token_returns_401(self) -> None: + with TestClient(app) as client: + resp = client.get("/api/v1/auth/me") + assert resp.status_code == 401 + + def test_expired_token_returns_401(self) -> None: + token = _make_jwt(exp_offset=-1) # already expired + with TestClient(app) as client: + resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) + assert resp.status_code == 401 + + def test_wrong_signature_returns_401(self) -> None: + token = _make_jwt(secret="totally-wrong-secret") + with TestClient(app) as client: + resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) + assert resp.status_code == 401 + + def test_missing_sub_claim_returns_401(self) -> None: + token = _make_jwt(include_sub=False) + with TestClient(app) as client: + resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) + assert resp.status_code == 401 + + def test_malformed_token_returns_401(self) -> None: + with TestClient(app) as client: + resp = client.get( + "/api/v1/auth/me", headers={"Authorization": "Bearer not.a.jwt"} + ) + assert resp.status_code == 401 + + +# --------------------------------------------------------------------------- +# Rate limiter middleware +# --------------------------------------------------------------------------- + + +class TestRateLimitMiddleware: + """Each test uses a fresh unique user_id so windows never collide.""" + + def _unique_token(self, tier: str = "free") -> str: + return _make_jwt(user_id=str(uuid.uuid4()), tier=tier) + + def test_free_tier_allows_up_to_20_requests(self) -> None: + token = self._unique_token("free") + with TestClient(app) as client: + for _ in range(20): + resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) + assert resp.status_code == 200 + + def test_free_tier_blocks_21st_request(self) -> None: + token = self._unique_token("free") + with TestClient(app) as client: + for _ in range(20): + client.get("/api/v1/auth/me", headers=_auth_header(token)) + resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) + assert resp.status_code == 429 + + def test_429_includes_retry_after_header(self) -> None: + token = self._unique_token("free") + with TestClient(app) as client: + for _ in range(20): + client.get("/api/v1/auth/me", headers=_auth_header(token)) + resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) + assert resp.status_code == 429 + assert "retry-after" in resp.headers + retry_after = int(resp.headers["retry-after"]) + assert retry_after >= 1 + + def test_429_response_has_detail_field(self) -> None: + token = self._unique_token("free") + with TestClient(app) as client: + for _ in range(20): + client.get("/api/v1/auth/me", headers=_auth_header(token)) + resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) + assert resp.status_code == 429 + assert "detail" in resp.json() + + def test_pro_tier_allows_60_requests(self) -> None: + token = self._unique_token("pro") + with TestClient(app) as client: + # Sample: first 60 succeed, 61st is blocked. + for _ in range(60): + resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) + assert resp.status_code == 200 + resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) + assert resp.status_code == 429 + + def test_independent_users_have_separate_windows(self) -> None: + token_a = self._unique_token("free") + token_b = self._unique_token("free") + with TestClient(app) as client: + # Exhaust user A's quota. + for _ in range(20): + client.get("/api/v1/auth/me", headers=_auth_header(token_a)) + assert ( + client.get( + "/api/v1/auth/me", headers=_auth_header(token_a) + ).status_code + == 429 + ) + # User B's quota is untouched. + resp_b = client.get("/api/v1/auth/me", headers=_auth_header(token_b)) + assert resp_b.status_code == 200 + + def test_exempt_path_register_never_rate_limited(self) -> None: + """POST /auth/register is exempt — 25 calls should never return 429.""" + with TestClient(app) as client: + for i in range(25): + resp = client.post( + "/api/v1/auth/register", + json={"email": f"user{i}_{uuid.uuid4()}@example.com", "password": "pw"}, + ) + # 201 on first, 409 on email collision — but never 429. + assert resp.status_code != 429 + + def test_exempt_path_login_never_rate_limited(self) -> None: + """POST /auth/login is exempt — multiple failed attempts are not rate-limited.""" + with TestClient(app) as client: + for _ in range(25): + resp = client.post( + "/api/v1/auth/login", + json={"email": "nosuchuser@example.com", "password": "wrong"}, + ) + assert resp.status_code != 429 + + def test_exempt_path_health_never_rate_limited(self) -> None: + with TestClient(app) as client: + for _ in range(25): + resp = client.get("/api/v1/health") + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# Sanitizer middleware +# --------------------------------------------------------------------------- + + +class TestSanitizerMiddleware: + """Mock ``orchestrate`` to inject controlled strings into chat responses.""" + + _CHAT_PATH = "/api/v1/chat" + + def _token(self) -> str: + return _make_jwt(user_id=str(uuid.uuid4()), tier="pro") + + def _post_chat(self, client: TestClient, response_text: str) -> dict: + mock_response = ChatResponse(response=response_text, actions=[]) + with patch( + "app.api.routes.chat.orchestrate", + new_callable=AsyncMock, + return_value=mock_response, + ): + resp = client.post( + self._CHAT_PATH, + json=_CHAT_BODY, + headers=_auth_header(self._token()), + ) + assert resp.status_code == 200 + return resp.json() + + def test_clean_response_passes_through_unchanged(self) -> None: + with TestClient(app) as client: + data = self._post_chat(client, "Sure, I created the task for you.") + assert data["response"] == "Sure, I created the task for you." + + def test_strips_system_prompt_opener(self) -> None: + with TestClient(app) as client: + data = self._post_chat( + client, "You are an intent classifier. Route to task_agent." + ) + assert "You are" not in data["response"] + assert "[REDACTED]" in data["response"] + + def test_strips_known_fingerprint(self) -> None: + with TestClient(app) as client: + data = self._post_chat( + client, "Respond with just the agent name and nothing else." + ) + assert data["response"] == "[REDACTED]" + + def test_strips_tool_schema_fragment(self) -> None: + with TestClient(app) as client: + data = self._post_chat( + client, 'Here is the schema: {"type": "function", "name": "foo"}' + ) + assert '"type": "function"' not in data["response"] + + def test_strips_reasoning_tag(self) -> None: + with TestClient(app) as client: + data = self._post_chat( + client, "I should route this to calendar_agentDone." + ) + assert "" not in data["response"] + assert "[REDACTED]" in data["response"] + + def test_strips_available_agents_fragment(self) -> None: + with TestClient(app) as client: + data = self._post_chat( + client, "Available agents: task_agent, calendar_agent" + ) + assert "[REDACTED]" in data["response"] + + def test_sanitizer_does_not_activate_for_non_chat_path(self) -> None: + """GET /api/v1/plans/playbook should pass through the sanitizer untouched.""" + token = self._token() + with TestClient(app) as client: + resp = client.get( + "/api/v1/plans/playbook", + headers=_auth_header(token), + ) + # The sanitizer should not interfere — just check it returns something + # (200 or whatever the route returns; we only care it's not broken). + assert resp.status_code in (200, 401, 403, 404) + + def test_sanitizer_preserves_empty_response(self) -> None: + with TestClient(app) as client: + data = self._post_chat(client, "") + assert data["response"] == ""