step 9 complete: auth middleware, tier-aware rate limiter, and response sanitizer
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
304
tests/test_middleware.py
Normal file
304
tests/test_middleware.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""Tests for Step 9 middleware: auth, rate limiting, and sanitizer.
|
||||
|
||||
Auth tests: validated via GET /api/v1/auth/me (requires a Bearer JWT).
|
||||
Rate limit: use unique user UUIDs per test so windows are independent;
|
||||
the free-tier threshold (20 req/min) is exercised directly.
|
||||
Sanitizer: the orchestrator is mocked to inject controlled prompt
|
||||
fragments, and the chat endpoint response body is inspected.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from jose import jwt
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.main import app
|
||||
from app.schemas import ChatResponse
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CHAT_BODY = {
|
||||
"message": "hello",
|
||||
"context": {
|
||||
"user_profile": {},
|
||||
"relevant_documents": [],
|
||||
"recent_tasks": [],
|
||||
"conversation_history": [],
|
||||
},
|
||||
"execution_mode": "direct",
|
||||
}
|
||||
|
||||
|
||||
def _make_jwt(
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
email: str = "test@example.com",
|
||||
tier: str = "free",
|
||||
exp_offset: int = 3600,
|
||||
secret: str | None = None,
|
||||
include_sub: bool = True,
|
||||
) -> str:
|
||||
"""Mint a test JWT signed with the configured (or custom) secret."""
|
||||
uid = user_id or str(uuid.uuid4())
|
||||
now = int(time.time())
|
||||
payload: dict = {
|
||||
"email": email,
|
||||
"tier": tier,
|
||||
"exp": now + exp_offset,
|
||||
"iat": now,
|
||||
}
|
||||
if include_sub:
|
||||
payload["sub"] = uid
|
||||
key = secret or settings.JWT_SECRET
|
||||
return jwt.encode(payload, key, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def _auth_header(token: str) -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuthMiddleware:
|
||||
"""Tests exercised via GET /api/v1/auth/me."""
|
||||
|
||||
def test_valid_token_returns_profile(self) -> None:
|
||||
uid = str(uuid.uuid4())
|
||||
token = _make_jwt(user_id=uid, email="alice@example.com", tier="pro")
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == uid
|
||||
assert data["email"] == "alice@example.com"
|
||||
assert data["tier"] == "pro"
|
||||
|
||||
def test_missing_token_returns_401(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_expired_token_returns_401(self) -> None:
|
||||
token = _make_jwt(exp_offset=-1) # already expired
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_wrong_signature_returns_401(self) -> None:
|
||||
token = _make_jwt(secret="totally-wrong-secret")
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_missing_sub_claim_returns_401(self) -> None:
|
||||
token = _make_jwt(include_sub=False)
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_malformed_token_returns_401(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get(
|
||||
"/api/v1/auth/me", headers={"Authorization": "Bearer not.a.jwt"}
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rate limiter middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
|
||||
"""Each test uses a fresh unique user_id so windows never collide."""
|
||||
|
||||
def _unique_token(self, tier: str = "free") -> str:
|
||||
return _make_jwt(user_id=str(uuid.uuid4()), tier=tier)
|
||||
|
||||
def test_free_tier_allows_up_to_20_requests(self) -> None:
|
||||
token = self._unique_token("free")
|
||||
with TestClient(app) as client:
|
||||
for _ in range(20):
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_free_tier_blocks_21st_request(self) -> None:
|
||||
token = self._unique_token("free")
|
||||
with TestClient(app) as client:
|
||||
for _ in range(20):
|
||||
client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 429
|
||||
|
||||
def test_429_includes_retry_after_header(self) -> None:
|
||||
token = self._unique_token("free")
|
||||
with TestClient(app) as client:
|
||||
for _ in range(20):
|
||||
client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 429
|
||||
assert "retry-after" in resp.headers
|
||||
retry_after = int(resp.headers["retry-after"])
|
||||
assert retry_after >= 1
|
||||
|
||||
def test_429_response_has_detail_field(self) -> None:
|
||||
token = self._unique_token("free")
|
||||
with TestClient(app) as client:
|
||||
for _ in range(20):
|
||||
client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 429
|
||||
assert "detail" in resp.json()
|
||||
|
||||
def test_pro_tier_allows_60_requests(self) -> None:
|
||||
token = self._unique_token("pro")
|
||||
with TestClient(app) as client:
|
||||
# Sample: first 60 succeed, 61st is blocked.
|
||||
for _ in range(60):
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 200
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 429
|
||||
|
||||
def test_independent_users_have_separate_windows(self) -> None:
|
||||
token_a = self._unique_token("free")
|
||||
token_b = self._unique_token("free")
|
||||
with TestClient(app) as client:
|
||||
# Exhaust user A's quota.
|
||||
for _ in range(20):
|
||||
client.get("/api/v1/auth/me", headers=_auth_header(token_a))
|
||||
assert (
|
||||
client.get(
|
||||
"/api/v1/auth/me", headers=_auth_header(token_a)
|
||||
).status_code
|
||||
== 429
|
||||
)
|
||||
# User B's quota is untouched.
|
||||
resp_b = client.get("/api/v1/auth/me", headers=_auth_header(token_b))
|
||||
assert resp_b.status_code == 200
|
||||
|
||||
def test_exempt_path_register_never_rate_limited(self) -> None:
|
||||
"""POST /auth/register is exempt — 25 calls should never return 429."""
|
||||
with TestClient(app) as client:
|
||||
for i in range(25):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": f"user{i}_{uuid.uuid4()}@example.com", "password": "pw"},
|
||||
)
|
||||
# 201 on first, 409 on email collision — but never 429.
|
||||
assert resp.status_code != 429
|
||||
|
||||
def test_exempt_path_login_never_rate_limited(self) -> None:
|
||||
"""POST /auth/login is exempt — multiple failed attempts are not rate-limited."""
|
||||
with TestClient(app) as client:
|
||||
for _ in range(25):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "nosuchuser@example.com", "password": "wrong"},
|
||||
)
|
||||
assert resp.status_code != 429
|
||||
|
||||
def test_exempt_path_health_never_rate_limited(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
for _ in range(25):
|
||||
resp = client.get("/api/v1/health")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sanitizer middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizerMiddleware:
|
||||
"""Mock ``orchestrate`` to inject controlled strings into chat responses."""
|
||||
|
||||
_CHAT_PATH = "/api/v1/chat"
|
||||
|
||||
def _token(self) -> str:
|
||||
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
|
||||
|
||||
def _post_chat(self, client: TestClient, response_text: str) -> dict:
|
||||
mock_response = ChatResponse(response=response_text, actions=[])
|
||||
with patch(
|
||||
"app.api.routes.chat.orchestrate",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
):
|
||||
resp = client.post(
|
||||
self._CHAT_PATH,
|
||||
json=_CHAT_BODY,
|
||||
headers=_auth_header(self._token()),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
return resp.json()
|
||||
|
||||
def test_clean_response_passes_through_unchanged(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(client, "Sure, I created the task for you.")
|
||||
assert data["response"] == "Sure, I created the task for you."
|
||||
|
||||
def test_strips_system_prompt_opener(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(
|
||||
client, "You are an intent classifier. Route to task_agent."
|
||||
)
|
||||
assert "You are" not in data["response"]
|
||||
assert "[REDACTED]" in data["response"]
|
||||
|
||||
def test_strips_known_fingerprint(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(
|
||||
client, "Respond with just the agent name and nothing else."
|
||||
)
|
||||
assert data["response"] == "[REDACTED]"
|
||||
|
||||
def test_strips_tool_schema_fragment(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(
|
||||
client, 'Here is the schema: {"type": "function", "name": "foo"}'
|
||||
)
|
||||
assert '"type": "function"' not in data["response"]
|
||||
|
||||
def test_strips_reasoning_tag(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(
|
||||
client, "<thinking>I should route this to calendar_agent</thinking>Done."
|
||||
)
|
||||
assert "<thinking>" not in data["response"]
|
||||
assert "[REDACTED]" in data["response"]
|
||||
|
||||
def test_strips_available_agents_fragment(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(
|
||||
client, "Available agents: task_agent, calendar_agent"
|
||||
)
|
||||
assert "[REDACTED]" in data["response"]
|
||||
|
||||
def test_sanitizer_does_not_activate_for_non_chat_path(self) -> None:
|
||||
"""GET /api/v1/plans/playbook should pass through the sanitizer untouched."""
|
||||
token = self._token()
|
||||
with TestClient(app) as client:
|
||||
resp = client.get(
|
||||
"/api/v1/plans/playbook",
|
||||
headers=_auth_header(token),
|
||||
)
|
||||
# The sanitizer should not interfere — just check it returns something
|
||||
# (200 or whatever the route returns; we only care it's not broken).
|
||||
assert resp.status_code in (200, 401, 403, 404)
|
||||
|
||||
def test_sanitizer_preserves_empty_response(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(client, "")
|
||||
assert data["response"] == ""
|
||||
Reference in New Issue
Block a user