# File: api/tests/test_middleware.py
# 305 lines · 11 KiB · Python

"""Tests for Step 9 middleware: auth, rate limiting, and sanitizer.
Auth tests: validated via GET /api/v1/auth/me (requires a Bearer JWT).
Rate limit: use unique user UUIDs per test so windows are independent;
the free-tier threshold (20 req/min) is exercised directly.
Sanitizer: the orchestrator is mocked to inject controlled prompt
fragments, and the chat endpoint response body is inspected.
"""
from __future__ import annotations
import time
import uuid
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from jose import jwt
from app.config.settings import settings
from app.main import app
from app.schemas import ChatResponse
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Request body posted to the chat endpoint by the sanitizer tests: a fixed
# message, an entirely empty context, and "direct" execution mode.
_CHAT_BODY = {
    "message": "hello",
    "context": {
        "user_profile": {},
        "relevant_documents": [],
        "recent_tasks": [],
        "conversation_history": [],
    },
    "execution_mode": "direct",
}
def _make_jwt(
    *,
    user_id: str | None = None,
    email: str = "test@example.com",
    tier: str = "free",
    exp_offset: int = 3600,
    secret: str | None = None,
    include_sub: bool = True,
) -> str:
    """Mint a test JWT signed with the configured (or custom) secret.

    Args:
        user_id: Value for the ``sub`` claim; a random UUID4 string is used
            when this is ``None`` or empty.
        email: Value for the ``email`` claim.
        tier: Value for the ``tier`` claim.
        exp_offset: Seconds added to the current time to form ``exp``;
            pass a negative number to mint an already-expired token.
        secret: Signing key; ``settings.JWT_SECRET`` is used when this is
            ``None`` or empty.
        include_sub: When ``False`` the ``sub`` claim is omitted entirely.

    Returns:
        The encoded token as produced by ``jwt.encode``.
    """
    now = int(time.time())
    payload: dict[str, object] = {
        "email": email,
        "tier": tier,
        "exp": now + exp_offset,
        "iat": now,
    }
    if include_sub:
        # Only generate a fallback id when the claim is actually emitted.
        payload["sub"] = user_id or str(uuid.uuid4())
    key = secret or settings.JWT_SECRET
    return jwt.encode(payload, key, algorithm=settings.JWT_ALGORITHM)
def _auth_header(token: str) -> dict[str, str]:
    """Return an ``Authorization`` header dict carrying *token* as a Bearer credential."""
    return {"Authorization": f"Bearer {token}"}
# ---------------------------------------------------------------------------
# Auth middleware
# ---------------------------------------------------------------------------
class TestAuthMiddleware:
    """Tests exercised via GET /api/v1/auth/me."""

    _ME_PATH = "/api/v1/auth/me"

    def _get_me(self, headers: dict[str, str] | None = None):
        """Send GET /api/v1/auth/me through a fresh TestClient.

        Args:
            headers: Optional request headers; ``None`` sends no
                ``Authorization`` header at all.

        Returns:
            The response object returned by ``client.get``.
        """
        with TestClient(app) as client:
            return client.get(self._ME_PATH, headers=headers)

    def test_valid_token_returns_profile(self) -> None:
        uid = str(uuid.uuid4())
        token = _make_jwt(user_id=uid, email="alice@example.com", tier="pro")

        resp = self._get_me(_auth_header(token))

        assert resp.status_code == 200
        data = resp.json()
        # The profile must echo the claims carried by the token.
        assert data["id"] == uid
        assert data["email"] == "alice@example.com"
        assert data["tier"] == "pro"

    def test_missing_token_returns_401(self) -> None:
        resp = self._get_me()
        assert resp.status_code == 401

    def test_expired_token_returns_401(self) -> None:
        token = _make_jwt(exp_offset=-1)  # already expired
        resp = self._get_me(_auth_header(token))
        assert resp.status_code == 401

    def test_wrong_signature_returns_401(self) -> None:
        token = _make_jwt(secret="totally-wrong-secret")
        resp = self._get_me(_auth_header(token))
        assert resp.status_code == 401

    def test_missing_sub_claim_returns_401(self) -> None:
        token = _make_jwt(include_sub=False)
        resp = self._get_me(_auth_header(token))
        assert resp.status_code == 401

    def test_malformed_token_returns_401(self) -> None:
        # Three dot-separated segments, but none of them is valid JWT content.
        resp = self._get_me({"Authorization": "Bearer not.a.jwt"})
        assert resp.status_code == 401
# ---------------------------------------------------------------------------
# Rate limiter middleware
# ---------------------------------------------------------------------------
class TestRateLimitMiddleware:
    """Each test uses a fresh unique user_id so windows never collide."""

    _ME_PATH = "/api/v1/auth/me"
    # Per-window request quotas exercised by these tests.
    _FREE_LIMIT = 20
    _PRO_LIMIT = 60
    # Number of calls made against exempt paths; above the free-tier quota.
    _EXEMPT_ATTEMPTS = 25

    def _unique_token(self, tier: str = "free") -> str:
        """Mint a token for a brand-new random user on the given tier."""
        return _make_jwt(user_id=str(uuid.uuid4()), tier=tier)

    def _get_me(self, client: TestClient, token: str):
        """Send one authenticated GET /api/v1/auth/me and return the response."""
        return client.get(self._ME_PATH, headers=_auth_header(token))

    def _send(self, client: TestClient, token: str, count: int) -> list[int]:
        """Send *count* authenticated requests and return their status codes in order."""
        return [self._get_me(client, token).status_code for _ in range(count)]

    def test_free_tier_allows_up_to_20_requests(self) -> None:
        token = self._unique_token("free")
        with TestClient(app) as client:
            statuses = self._send(client, token, self._FREE_LIMIT)
        # Every request within the quota must succeed, not just the last one.
        assert statuses == [200] * self._FREE_LIMIT

    def test_free_tier_blocks_21st_request(self) -> None:
        token = self._unique_token("free")
        with TestClient(app) as client:
            self._send(client, token, self._FREE_LIMIT)
            resp = self._get_me(client, token)
        assert resp.status_code == 429

    def test_429_includes_retry_after_header(self) -> None:
        token = self._unique_token("free")
        with TestClient(app) as client:
            self._send(client, token, self._FREE_LIMIT)
            resp = self._get_me(client, token)
        assert resp.status_code == 429
        assert "retry-after" in resp.headers
        retry_after = int(resp.headers["retry-after"])
        assert retry_after >= 1

    def test_429_response_has_detail_field(self) -> None:
        token = self._unique_token("free")
        with TestClient(app) as client:
            self._send(client, token, self._FREE_LIMIT)
            resp = self._get_me(client, token)
        assert resp.status_code == 429
        assert "detail" in resp.json()

    def test_pro_tier_allows_60_requests(self) -> None:
        token = self._unique_token("pro")
        with TestClient(app) as client:
            # First 60 succeed, 61st is blocked.
            statuses = self._send(client, token, self._PRO_LIMIT)
            over_limit = self._get_me(client, token)
        assert statuses == [200] * self._PRO_LIMIT
        assert over_limit.status_code == 429

    def test_independent_users_have_separate_windows(self) -> None:
        token_a = self._unique_token("free")
        token_b = self._unique_token("free")
        with TestClient(app) as client:
            # Exhaust user A's quota.
            self._send(client, token_a, self._FREE_LIMIT)
            assert self._get_me(client, token_a).status_code == 429
            # User B's quota is untouched.
            resp_b = self._get_me(client, token_b)
        assert resp_b.status_code == 200

    def test_exempt_path_register_never_rate_limited(self) -> None:
        """POST /auth/register is exempt — 25 calls should never return 429."""
        with TestClient(app) as client:
            for i in range(self._EXEMPT_ATTEMPTS):
                resp = client.post(
                    "/api/v1/auth/register",
                    json={
                        "email": f"user{i}_{uuid.uuid4()}@example.com",
                        "password": "pw",
                    },
                )
                # 201 on first, 409 on email collision — but never 429.
                assert resp.status_code != 429

    def test_exempt_path_login_never_rate_limited(self) -> None:
        """POST /auth/login is exempt — multiple failed attempts are not rate-limited."""
        with TestClient(app) as client:
            for _ in range(self._EXEMPT_ATTEMPTS):
                resp = client.post(
                    "/api/v1/auth/login",
                    json={"email": "nosuchuser@example.com", "password": "wrong"},
                )
                assert resp.status_code != 429

    def test_exempt_path_health_never_rate_limited(self) -> None:
        with TestClient(app) as client:
            for _ in range(self._EXEMPT_ATTEMPTS):
                resp = client.get("/api/v1/health")
                assert resp.status_code == 200
# ---------------------------------------------------------------------------
# Sanitizer middleware
# ---------------------------------------------------------------------------
class TestSanitizerMiddleware:
    """Mock ``orchestrate`` to inject controlled strings into chat responses."""

    _CHAT_PATH = "/api/v1/chat"

    def _token(self) -> str:
        """Mint a pro-tier token for a brand-new random user."""
        return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")

    def _post_chat(self, client: TestClient, response_text: str) -> dict:
        """POST the canned chat body with ``orchestrate`` patched.

        Args:
            client: An already-opened TestClient.
            response_text: Text the patched ``orchestrate`` returns as the
                ``response`` field of its ``ChatResponse``.

        Returns:
            The decoded JSON body of the chat endpoint's reply.
        """
        mock_response = ChatResponse(response=response_text, actions=[])
        with patch(
            "app.api.routes.chat.orchestrate",
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            resp = client.post(
                self._CHAT_PATH,
                json=_CHAT_BODY,
                headers=_auth_header(self._token()),
            )
        # Surface the body on failure so a non-200 is diagnosable from the log.
        assert resp.status_code == 200, resp.text
        return resp.json()

    def _sanitized_text(self, response_text: str) -> str:
        """Run *response_text* through the chat endpoint and return the ``response`` field."""
        with TestClient(app) as client:
            return self._post_chat(client, response_text)["response"]

    def test_clean_response_passes_through_unchanged(self) -> None:
        text = self._sanitized_text("Sure, I created the task for you.")
        assert text == "Sure, I created the task for you."

    def test_strips_system_prompt_opener(self) -> None:
        text = self._sanitized_text(
            "You are an intent classifier. Route to task_agent."
        )
        assert "You are" not in text
        assert "[REDACTED]" in text

    def test_strips_known_fingerprint(self) -> None:
        text = self._sanitized_text(
            "Respond with just the agent name and nothing else."
        )
        assert text == "[REDACTED]"

    def test_strips_tool_schema_fragment(self) -> None:
        text = self._sanitized_text(
            'Here is the schema: {"type": "function", "name": "foo"}'
        )
        assert '"type": "function"' not in text

    def test_strips_reasoning_tag(self) -> None:
        text = self._sanitized_text(
            "<thinking>I should route this to calendar_agent</thinking>Done."
        )
        assert "<thinking>" not in text
        assert "[REDACTED]" in text

    def test_strips_available_agents_fragment(self) -> None:
        text = self._sanitized_text("Available agents: task_agent, calendar_agent")
        assert "[REDACTED]" in text

    def test_sanitizer_does_not_activate_for_non_chat_path(self) -> None:
        """GET /api/v1/plans/playbook should pass through the sanitizer untouched."""
        token = self._token()
        with TestClient(app) as client:
            resp = client.get(
                "/api/v1/plans/playbook",
                headers=_auth_header(token),
            )
        # The sanitizer should not interfere — just check it returns something
        # (200 or whatever the route returns; we only care it's not broken).
        # NOTE(review): this only rules out unexpected statuses such as 5xx; it
        # does not inspect the body, so it cannot prove the payload is unmodified.
        assert resp.status_code in (200, 401, 403, 404)

    def test_sanitizer_preserves_empty_response(self) -> None:
        text = self._sanitized_text("")
        assert text == ""