- Add app/core/deep_agent.py with Home and Floating supervisor graphs built with LangGraph's create_react_agent (hierarchical pattern).
- Strip the ChatAgent classes from all 4 agent files, keeping the @tool functions.
- Rewrite output_formatter.py for an event-based (token / tool_end / mutations) stream.
- Update device_ws.py to use run_home_stream / run_floating_stream.
- Rewrite the chat.py REST route to use run_home.
- Add the update_core_memory tool to both supervisors.
- Add langgraph>=0.3.0 to requirements.txt.
- Remove orchestrator.py, execution_plan.py, agent_registry.py and plans.py.
- Remove PlanAction, PlanStep, ExecutionPlan and execution_mode from the schemas.
- Update all affected tests to match the new API.
- Remove 6 deprecated test files for the deleted modules.
- Clean up stale docstrings referencing the removed orchestrator.

320 lines · 12 KiB · Python
"""Tests for Step 9 middleware: auth, rate limiting, and sanitizer.
|
|
|
|
Auth tests: validated via GET /api/v1/auth/me (requires a Bearer JWT).
|
|
Rate limit: use unique user UUIDs per test so windows are independent;
|
|
the free-tier threshold (20 req/min) is exercised directly.
|
|
Sanitizer: the orchestrator is mocked to inject controlled prompt
|
|
fragments, and the chat endpoint response body is inspected.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
import uuid
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from jose import jwt
|
|
|
|
from app.config.settings import settings
|
|
from app.db import get_session
|
|
from app.main import app
|
|
from tests.conftest import TEST_USER_IDS
|
|
|
|
# ---------------------------------------------------------------------------
# Autouse: redirect all DB access to the in-memory SQLite test engine.
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Serve the test SQLite session to every ``get_session`` dependency.

    The override is installed before each test runs and removed again in the
    fixture's teardown (after ``yield``), so it does not outlive the test.
    """

    async def _provide_test_session():
        # Stand-in dependency: hands out the fixture-provided session.
        yield db_session

    app.dependency_overrides[get_session] = _provide_test_session
    yield
    # Tolerate the key already being gone rather than raising on teardown.
    app.dependency_overrides.pop(get_session, None)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


_CHAT_BODY = {
|
|
"message": "hello",
|
|
"context": {
|
|
"user_profile": {},
|
|
"relevant_documents": [],
|
|
"recent_tasks": [],
|
|
"conversation_history": [],
|
|
},
|
|
}
|
|
|
|
|
|
def _make_jwt(
    *,
    user_id: str | None = None,
    email: str = "test@example.com",
    tier: str = "free",
    exp_offset: int = 3600,
    secret: str | None = None,
    include_sub: bool = True,
) -> str:
    """Return a signed JWT for use in tests.

    Args:
        user_id: Value for the ``sub`` claim; a random UUID4 string is used
            when this is falsy.
        email: Value for the ``email`` claim.
        tier: Value for the ``tier`` claim.
        exp_offset: Seconds from now until ``exp``; a negative value yields
            an already-expired token.
        secret: Signing key; ``settings.JWT_SECRET`` is used when this is
            falsy, so passing any other key produces a bad-signature token.
        include_sub: When False the ``sub`` claim is omitted entirely.

    Returns:
        The encoded token, signed with ``settings.JWT_ALGORITHM``.
    """
    subject = user_id or str(uuid.uuid4())
    issued_at = int(time.time())

    # Claim order is kept stable (sub, when present, is appended last).
    claims: dict = {
        "email": email,
        "tier": tier,
        "exp": issued_at + exp_offset,
        "iat": issued_at,
    }
    if include_sub:
        claims["sub"] = subject

    signing_key = secret if secret else settings.JWT_SECRET
    return jwt.encode(claims, signing_key, algorithm=settings.JWT_ALGORITHM)


def _auth_header(token: str) -> dict[str, str]:
|
|
return {"Authorization": f"Bearer {token}"}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Auth middleware
# ---------------------------------------------------------------------------


class TestAuthMiddleware:
    """Auth middleware behaviour, observed through GET /api/v1/auth/me."""

    _ME_PATH = "/api/v1/auth/me"

    def _get_me(self, headers: dict[str, str] | None = None):
        """Issue a single GET /auth/me from a fresh client and return the response."""
        with TestClient(app) as client:
            return client.get(self._ME_PATH, headers=headers)

    def test_valid_token_returns_profile(self) -> None:
        # The seeded pro user makes the subscription lookup report 'pro'.
        uid = TEST_USER_IDS["pro"]
        token = _make_jwt(user_id=uid, email="pro@test.com", tier="pro")

        resp = self._get_me(_auth_header(token))

        assert resp.status_code == 200
        profile = resp.json()
        observed = {key: profile[key] for key in ("id", "email", "tier")}
        assert observed == {"id": uid, "email": "pro@test.com", "tier": "pro"}

    def test_missing_token_returns_401(self) -> None:
        assert self._get_me().status_code == 401

    def test_expired_token_returns_401(self) -> None:
        expired = _make_jwt(exp_offset=-1)  # exp lies one second in the past
        assert self._get_me(_auth_header(expired)).status_code == 401

    def test_wrong_signature_returns_401(self) -> None:
        forged = _make_jwt(secret="totally-wrong-secret")
        assert self._get_me(_auth_header(forged)).status_code == 401

    def test_missing_sub_claim_returns_401(self) -> None:
        subjectless = _make_jwt(include_sub=False)
        assert self._get_me(_auth_header(subjectless)).status_code == 401

    def test_malformed_token_returns_401(self) -> None:
        garbage = {"Authorization": "Bearer not.a.jwt"}
        assert self._get_me(garbage).status_code == 401


# ---------------------------------------------------------------------------
# Rate limiter middleware
# ---------------------------------------------------------------------------


class TestRateLimitMiddleware:
    """Per-user rate limiting, observed through GET /api/v1/auth/me.

    Each test mints tokens for fresh random user_ids, so the per-user windows
    of different tests never collide.
    """

    _ME_PATH = "/api/v1/auth/me"
    _FREE_LIMIT = 20  # free-tier requests allowed per window
    _PRO_LIMIT = 60  # pro-tier requests allowed per window
    # More calls than the free-tier limit; used to probe exempt paths.
    _EXEMPT_ATTEMPTS = _FREE_LIMIT + 5

    def _unique_token(self, tier: str = "free") -> str:
        """Mint a token for a brand-new random user on *tier*."""
        return _make_jwt(user_id=str(uuid.uuid4()), tier=tier)

    def _exhaust(self, client: TestClient, token: str, count: int) -> None:
        """Send *count* requests as *token*, asserting every one is allowed.

        Checking each status means a later 429 proves the limiter fired
        exactly at the threshold rather than at some earlier request.
        """
        headers = _auth_header(token)
        for attempt in range(1, count + 1):
            status = client.get(self._ME_PATH, headers=headers).status_code
            assert status == 200, f"request {attempt}/{count} returned {status}"

    def _request_past_limit(self, client: TestClient, token: str, limit: int):
        """Use up *limit* allowed requests, then return the response to one more."""
        self._exhaust(client, token, limit)
        return client.get(self._ME_PATH, headers=_auth_header(token))

    def test_free_tier_allows_up_to_20_requests(self) -> None:
        token = self._unique_token("free")
        with TestClient(app) as client:
            self._exhaust(client, token, self._FREE_LIMIT)

    def test_free_tier_blocks_21st_request(self) -> None:
        token = self._unique_token("free")
        with TestClient(app) as client:
            resp = self._request_past_limit(client, token, self._FREE_LIMIT)
        assert resp.status_code == 429

    def test_429_includes_retry_after_header(self) -> None:
        token = self._unique_token("free")
        with TestClient(app) as client:
            resp = self._request_past_limit(client, token, self._FREE_LIMIT)
        assert resp.status_code == 429
        assert "retry-after" in resp.headers
        retry_after = int(resp.headers["retry-after"])
        assert retry_after >= 1

    def test_429_response_has_detail_field(self) -> None:
        token = self._unique_token("free")
        with TestClient(app) as client:
            resp = self._request_past_limit(client, token, self._FREE_LIMIT)
        assert resp.status_code == 429
        assert "detail" in resp.json()

    def test_pro_tier_allows_60_requests(self) -> None:
        token = self._unique_token("pro")
        with TestClient(app) as client:
            # All 60 pro-tier requests must succeed; the 61st is blocked.
            resp = self._request_past_limit(client, token, self._PRO_LIMIT)
        assert resp.status_code == 429

    def test_independent_users_have_separate_windows(self) -> None:
        token_a = self._unique_token("free")
        token_b = self._unique_token("free")
        with TestClient(app) as client:
            # Push user A over the limit.
            resp_a = self._request_past_limit(client, token_a, self._FREE_LIMIT)
            assert resp_a.status_code == 429
            # User B's quota is untouched.
            resp_b = client.get(self._ME_PATH, headers=_auth_header(token_b))
        assert resp_b.status_code == 200

    # NOTE(review): the exempt-path requests below carry no bearer token. If the
    # limiter only tracks authenticated users, these tests would pass even
    # without an explicit exemption — confirm how anonymous requests are keyed.

    def test_exempt_path_register_never_rate_limited(self) -> None:
        """POST /auth/register is exempt — 25 calls should never return 429."""
        with TestClient(app) as client:
            for i in range(self._EXEMPT_ATTEMPTS):
                resp = client.post(
                    "/api/v1/auth/register",
                    json={"email": f"user{i}_{uuid.uuid4()}@example.com", "password": "pw"},
                )
                # Every email embeds a fresh uuid4, so no two calls collide;
                # whatever registration returns, it must not be a 429.
                assert resp.status_code != 429

    def test_exempt_path_login_never_rate_limited(self) -> None:
        """POST /auth/login is exempt — multiple failed attempts are not rate-limited."""
        with TestClient(app) as client:
            for _ in range(self._EXEMPT_ATTEMPTS):
                resp = client.post(
                    "/api/v1/auth/login",
                    json={"email": "nosuchuser@example.com", "password": "wrong"},
                )
                assert resp.status_code != 429

    def test_exempt_path_health_never_rate_limited(self) -> None:
        with TestClient(app) as client:
            for _ in range(self._EXEMPT_ATTEMPTS):
                resp = client.get("/api/v1/health")
                assert resp.status_code == 200


# ---------------------------------------------------------------------------
# Sanitizer middleware
# ---------------------------------------------------------------------------


class TestSanitizerMiddleware:
    """Mock ``run_home`` to inject controlled strings into chat responses."""

    _CHAT_PATH = "/api/v1/chat"

    def _token(self) -> str:
        # A fresh user per call keeps every request in its own rate-limit window.
        return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")

    def _post_chat(self, client: TestClient, response_text: str) -> dict:
        """POST the canned chat body while ``run_home`` returns *response_text*.

        Asserts a 200 status and returns the decoded JSON body.
        """
        with patch(
            "app.api.routes.chat.run_home",
            new_callable=AsyncMock,
            return_value=response_text,
        ):
            resp = client.post(
                self._CHAT_PATH,
                json=_CHAT_BODY,
                headers=_auth_header(self._token()),
            )
        assert resp.status_code == 200
        return resp.json()

    def test_clean_response_passes_through_unchanged(self) -> None:
        with TestClient(app) as client:
            data = self._post_chat(client, "Sure, I created the task for you.")
        assert data["response"] == "Sure, I created the task for you."

    def test_strips_system_prompt_opener(self) -> None:
        with TestClient(app) as client:
            data = self._post_chat(
                client, "You are an intent classifier. Route to task_agent."
            )
        assert "You are" not in data["response"]
        assert "[REDACTED]" in data["response"]

    def test_strips_known_fingerprint(self) -> None:
        with TestClient(app) as client:
            data = self._post_chat(
                client, "Respond with just the agent name and nothing else."
            )
        assert data["response"] == "[REDACTED]"

    def test_strips_tool_schema_fragment(self) -> None:
        with TestClient(app) as client:
            data = self._post_chat(
                client, 'Here is the schema: {"type": "function", "name": "foo"}'
            )
        assert '"type": "function"' not in data["response"]

    def test_strips_reasoning_tag(self) -> None:
        with TestClient(app) as client:
            data = self._post_chat(
                client, "<thinking>I should route this to calendar_agent</thinking>Done."
            )
        assert "<thinking>" not in data["response"]
        assert "[REDACTED]" in data["response"]

    def test_strips_available_agents_fragment(self) -> None:
        with TestClient(app) as client:
            data = self._post_chat(
                client, "Available agents: task_agent, calendar_agent"
            )
        assert "[REDACTED]" in data["response"]

    def test_sanitizer_does_not_activate_for_non_chat_path(self) -> None:
        """A non-chat route's response must pass through the sanitizer untouched.

        Uses GET /api/v1/auth/me as a live non-chat route. The former target,
        /api/v1/plans/playbook, was removed along with plans.py. Probing it could
        only yield a 404, which the old permissive assertion accepted.
        """
        uid = str(uuid.uuid4())
        token = _make_jwt(user_id=uid, tier="pro")
        with TestClient(app) as client:
            resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
        assert resp.status_code == 200
        assert resp.json()["id"] == uid
        assert "[REDACTED]" not in resp.text

    def test_sanitizer_preserves_empty_response(self) -> None:
        with TestClient(app) as client:
            data = self._post_chat(client, "")
        assert data["response"] == ""