"""Tests for Step 9 middleware: auth, rate limiting, and sanitizer. Auth tests: validated via GET /api/v1/auth/me (requires a Bearer JWT). Rate limit: use unique user UUIDs per test so windows are independent; the free-tier threshold (20 req/min) is exercised directly. Sanitizer: the orchestrator is mocked to inject controlled prompt fragments, and the chat endpoint response body is inspected. """ from __future__ import annotations import time import uuid from unittest.mock import AsyncMock, patch import pytest from fastapi.testclient import TestClient from jose import jwt from app.config.settings import settings from app.db import get_session from app.main import app from app.schemas import ChatResponse from tests.conftest import TEST_USER_IDS # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- # --------------------------------------------------------------------------- # Autouse: redirect all DB access to the in-memory SQLite test engine. # --------------------------------------------------------------------------- @pytest.fixture(autouse=True) def _override_db(db_session): """Route all get_session calls to the test SQLite session.""" async def _gen(): yield db_session app.dependency_overrides[get_session] = _gen yield app.dependency_overrides.pop(get_session, None) _CHAT_BODY = { "message": "hello", "context": { "user_profile": {}, "relevant_documents": [], "recent_tasks": [], "conversation_history": [], }, "execution_mode": "direct", } def _make_jwt( *, user_id: str | None = None, email: str = "test@example.com", tier: str = "free", exp_offset: int = 3600, secret: str | None = None, include_sub: bool = True, ) -> str: """Mint a test JWT signed with the configured (or custom) secret.""" uid = user_id or str(uuid.uuid4()) now = int(time.time()) payload: dict = { "email": email, "tier": tier, "exp": now + exp_offset, "iat": now, } if include_sub: payload["sub"] = uid key = secret or settings.JWT_SECRET return jwt.encode(payload, key, algorithm=settings.JWT_ALGORITHM) def _auth_header(token: str) -> dict[str, str]: return {"Authorization": f"Bearer {token}"} # --------------------------------------------------------------------------- # Auth middleware # --------------------------------------------------------------------------- class TestAuthMiddleware: """Tests exercised via GET /api/v1/auth/me.""" def test_valid_token_returns_profile(self) -> None: # Use the seeded pro user so the subscription lookup returns 'pro'. uid = TEST_USER_IDS["pro"] token = _make_jwt(user_id=uid, email="pro@test.com", tier="pro") with TestClient(app) as client: resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) assert resp.status_code == 200 data = resp.json() assert data["id"] == uid assert data["email"] == "pro@test.com" assert data["tier"] == "pro" def test_missing_token_returns_401(self) -> None: with TestClient(app) as client: resp = client.get("/api/v1/auth/me") assert resp.status_code == 401 def test_expired_token_returns_401(self) -> None: token = _make_jwt(exp_offset=-1) # already expired with TestClient(app) as client: resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) assert resp.status_code == 401 def test_wrong_signature_returns_401(self) -> None: token = _make_jwt(secret="totally-wrong-secret") with TestClient(app) as client: resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) assert resp.status_code == 401 def test_missing_sub_claim_returns_401(self) -> None: token = _make_jwt(include_sub=False) with TestClient(app) as client: resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) assert resp.status_code == 401 def test_malformed_token_returns_401(self) -> None: with TestClient(app) as client: resp = client.get( "/api/v1/auth/me", headers={"Authorization": "Bearer not.a.jwt"} ) assert resp.status_code == 401 # --------------------------------------------------------------------------- # Rate limiter middleware # --------------------------------------------------------------------------- class TestRateLimitMiddleware: """Each test uses a fresh unique user_id so windows never collide.""" def _unique_token(self, tier: str = "free") -> str: return _make_jwt(user_id=str(uuid.uuid4()), tier=tier) def test_free_tier_allows_up_to_20_requests(self) -> None: token = self._unique_token("free") with TestClient(app) as client: for _ in range(20): resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) assert resp.status_code == 200 def test_free_tier_blocks_21st_request(self) -> None: token = self._unique_token("free") with TestClient(app) as client: for _ in range(20): client.get("/api/v1/auth/me", headers=_auth_header(token)) resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) assert resp.status_code == 429 def test_429_includes_retry_after_header(self) -> None: token = self._unique_token("free") with TestClient(app) as client: for _ in range(20): client.get("/api/v1/auth/me", headers=_auth_header(token)) resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) assert resp.status_code == 429 assert "retry-after" in resp.headers retry_after = int(resp.headers["retry-after"]) assert retry_after >= 1 def test_429_response_has_detail_field(self) -> None: token = self._unique_token("free") with TestClient(app) as client: for _ in range(20): client.get("/api/v1/auth/me", headers=_auth_header(token)) resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) assert resp.status_code == 429 assert "detail" in resp.json() def test_pro_tier_allows_60_requests(self) -> None: token = self._unique_token("pro") with TestClient(app) as client: # Sample: first 60 succeed, 61st is blocked. for _ in range(60): resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) assert resp.status_code == 200 resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) assert resp.status_code == 429 def test_independent_users_have_separate_windows(self) -> None: token_a = self._unique_token("free") token_b = self._unique_token("free") with TestClient(app) as client: # Exhaust user A's quota. for _ in range(20): client.get("/api/v1/auth/me", headers=_auth_header(token_a)) assert ( client.get( "/api/v1/auth/me", headers=_auth_header(token_a) ).status_code == 429 ) # User B's quota is untouched. resp_b = client.get("/api/v1/auth/me", headers=_auth_header(token_b)) assert resp_b.status_code == 200 def test_exempt_path_register_never_rate_limited(self) -> None: """POST /auth/register is exempt — 25 calls should never return 429.""" with TestClient(app) as client: for i in range(25): resp = client.post( "/api/v1/auth/register", json={"email": f"user{i}_{uuid.uuid4()}@example.com", "password": "pw"}, ) # 201 on first, 409 on email collision — but never 429. assert resp.status_code != 429 def test_exempt_path_login_never_rate_limited(self) -> None: """POST /auth/login is exempt — multiple failed attempts are not rate-limited.""" with TestClient(app) as client: for _ in range(25): resp = client.post( "/api/v1/auth/login", json={"email": "nosuchuser@example.com", "password": "wrong"}, ) assert resp.status_code != 429 def test_exempt_path_health_never_rate_limited(self) -> None: with TestClient(app) as client: for _ in range(25): resp = client.get("/api/v1/health") assert resp.status_code == 200 # --------------------------------------------------------------------------- # Sanitizer middleware # --------------------------------------------------------------------------- class TestSanitizerMiddleware: """Mock ``orchestrate`` to inject controlled strings into chat responses.""" _CHAT_PATH = "/api/v1/chat" def _token(self) -> str: return _make_jwt(user_id=str(uuid.uuid4()), tier="pro") def _post_chat(self, client: TestClient, response_text: str) -> dict: mock_response = ChatResponse(response=response_text, actions=[]) with patch( "app.api.routes.chat.orchestrate", new_callable=AsyncMock, return_value=mock_response, ): resp = client.post( self._CHAT_PATH, json=_CHAT_BODY, headers=_auth_header(self._token()), ) assert resp.status_code == 200 return resp.json() def test_clean_response_passes_through_unchanged(self) -> None: with TestClient(app) as client: data = self._post_chat(client, "Sure, I created the task for you.") assert data["response"] == "Sure, I created the task for you." def test_strips_system_prompt_opener(self) -> None: with TestClient(app) as client: data = self._post_chat( client, "You are an intent classifier. Route to task_agent." ) assert "You are" not in data["response"] assert "[REDACTED]" in data["response"] def test_strips_known_fingerprint(self) -> None: with TestClient(app) as client: data = self._post_chat( client, "Respond with just the agent name and nothing else." ) assert data["response"] == "[REDACTED]" def test_strips_tool_schema_fragment(self) -> None: with TestClient(app) as client: data = self._post_chat( client, 'Here is the schema: {"type": "function", "name": "foo"}' ) assert '"type": "function"' not in data["response"] def test_strips_reasoning_tag(self) -> None: with TestClient(app) as client: data = self._post_chat( client, "I should route this to calendar_agentDone." ) assert "" not in data["response"] assert "[REDACTED]" in data["response"] def test_strips_available_agents_fragment(self) -> None: with TestClient(app) as client: data = self._post_chat( client, "Available agents: task_agent, calendar_agent" ) assert "[REDACTED]" in data["response"] def test_sanitizer_does_not_activate_for_non_chat_path(self) -> None: """GET /api/v1/plans/playbook should pass through the sanitizer untouched.""" token = self._token() with TestClient(app) as client: resp = client.get( "/api/v1/plans/playbook", headers=_auth_header(token), ) # The sanitizer should not interfere — just check it returns something # (200 or whatever the route returns; we only care it's not broken). assert resp.status_code in (200, 401, 403, 404) def test_sanitizer_preserves_empty_response(self) -> None: with TestClient(app) as client: data = self._post_chat(client, "") assert data["response"] == ""