Files
api/tests/test_brief_agent.py
2026-04-18 22:18:53 +02:00

164 lines
5.6 KiB
Python

"""Tests for Phase 3: brief agent WS frame + REST fallback.
Coverage:
- run_home_brief streams non-empty text (mocked _run_single_agent_stream)
- run_project_brief streams non-empty text for a valid project UUID (mocked _run_single_agent_stream)
- _build_read_tools uses the read-only tool subset only (no mutating tools)
- POST /chat/brief home mode returns {response: "..."}
- POST /chat/brief project mode with an invalid project_id UUID → 422
- POST /chat/brief project mode without a project_id → 422
"""
from __future__ import annotations
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from tests.conftest import TEST_USER_IDS, auth_header
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Identifier of the "pro" test user; the unit tests call the brief agents as this user.
_USER_ID = TEST_USER_IDS["pro"]
# Context handed to the brief agents in the unit tests: an empty core memory.
_EMPTY_CONTEXT: dict[str, Any] = dict(core_memory={})
async def _fake_token_stream(*_args, **_kwargs) -> AsyncGenerator[tuple[str, Any], None]:
"""Fake _run_single_agent_stream that yields two token events."""
yield ("token", "Hello")
yield ("token", " world")
# ---------------------------------------------------------------------------
# Unit: run_home_brief streams non-empty text
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_home_brief_streams_text():
    """The token payloads yielded by run_home_brief join to the mocked agent's text."""
    stream_target = "app.core.brief_agent._run_single_agent_stream"
    with patch(stream_target, side_effect=_fake_token_stream):
        from app.core.brief_agent import run_home_brief

        # Keep only token events; any other event kinds are ignored.
        received = [
            str(payload)
            async for kind, payload in run_home_brief(_USER_ID, _EMPTY_CONTEXT)
            if kind == "token"
        ]
    assert "".join(received) == "Hello world"
# ---------------------------------------------------------------------------
# Unit: run_project_brief streams text with valid UUID
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_project_brief_streams_text():
    """With a well-formed project UUID, run_project_brief yields the mocked agent's tokens."""
    stream_target = "app.core.brief_agent._run_single_agent_stream"
    with patch(stream_target, side_effect=_fake_token_stream):
        from app.core.brief_agent import run_project_brief

        project_id = str(uuid.uuid4())
        # Keep only token events; any other event kinds are ignored.
        received = [
            str(payload)
            async for kind, payload in run_project_brief(_USER_ID, project_id, _EMPTY_CONTEXT)
            if kind == "token"
        ]
    assert "".join(received) == "Hello world"
# ---------------------------------------------------------------------------
# Unit: _build_read_tools uses read-only subset (no write tools)
# ---------------------------------------------------------------------------
def test_build_read_tools_read_only_subset():
    """_build_read_tools must expose every *_READ_TOOLS entry and no mutating tool.

    Tools are identified by their ``name`` attribute, falling back to
    ``__name__`` and finally ``str(tool)``. The same lookup is applied to the
    built tools and to the reference lists, so the two name sets are comparable.
    """
    from app.agents.note_agent import NOTE_READ_TOOLS
    from app.agents.project_agent import PROJECT_READ_TOOLS
    from app.agents.task_agent import TASK_READ_TOOLS
    from app.agents.timeline_agent import TIMELINE_READ_TOOLS
    from app.core.brief_agent import _build_read_tools

    def _tool_name(tool: Any) -> str:
        # Single definition of the name lookup, shared by both sides of the comparison.
        return getattr(tool, "name", None) or getattr(tool, "__name__", str(tool))

    # NOTE(review): the second argument is presumably the project id (None = home
    # brief); confirm against _build_read_tools' signature.
    tool_names = {_tool_name(t) for t in _build_read_tools(_USER_ID, None)}

    # Every read-only export must be present. Collect all misses so a failure
    # reports the complete list instead of stopping at the first one.
    expected = {
        _tool_name(t)
        for read_list in (TASK_READ_TOOLS, PROJECT_READ_TOOLS, TIMELINE_READ_TOOLS, NOTE_READ_TOOLS)
        for t in read_list
    }
    missing = expected - tool_names
    assert not missing, f"Read tools missing from _build_read_tools: {missing}"

    # No mutating tools (e.g. create_task, update_task, delete_task).
    mutating = frozenset({
        "create_task", "update_task", "delete_task",
        "create_project", "update_project", "delete_project",
        "create_note", "update_note", "delete_note",
        "memory_add", "memory_update", "memory_delete",
    })
    overlap = tool_names & mutating
    assert not overlap, f"Mutating tools in brief read-only subset: {overlap}"
# ---------------------------------------------------------------------------
# Integration: POST /chat/brief (REST fallback) — home and project modes
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's ``get_session`` dependency at the test ``db_session``.

    Because ``autouse=True`` is set, this runs for every test in the module,
    including the unit tests that never touch the app. If an override for
    ``get_session`` was already registered before this fixture ran, it is
    restored on teardown instead of being removed.
    """
    from app.db import get_session
    from app.main import app

    async def _gen():
        yield db_session

    # Remember any pre-existing override so teardown does not silently drop it.
    previous = app.dependency_overrides.get(get_session)
    app.dependency_overrides[get_session] = _gen
    yield
    if previous is None:
        app.dependency_overrides.pop(get_session, None)
    else:
        app.dependency_overrides[get_session] = previous
@pytest.mark.asyncio
async def test_rest_brief_home_returns_response(client):
    """POST /chat/brief in home mode returns the mocked brief's token text as ``response``."""

    async def _fake_home_brief(*_args, **_kwargs):
        # Accept any call signature (positional or keyword, extra parameters) so
        # this stub does not break when the route's call to run_home_brief changes.
        yield ("token", "Today looks light.")

    with (
        patch("app.api.routes.chat.run_home_brief", side_effect=_fake_home_brief),
        patch(
            "app.api.routes.chat.MemoryMiddleware.enrich_context",
            new=AsyncMock(return_value={}),
        ),
    ):
        res = client.post(
            "/api/v1/chat/brief",
            json={"mode": "home"},
            headers=auth_header("pro"),
        )
    # Include the body in the failure message so a non-200 is diagnosable.
    assert res.status_code == 200, res.text
    data = res.json()
    assert data["response"] == "Today looks light."
@pytest.mark.asyncio
async def test_rest_brief_project_invalid_uuid_returns_422(client):
    """A project-mode brief whose project_id is not a UUID is rejected with 422."""
    payload = {"mode": "project", "project_id": "not-a-uuid"}
    response = client.post("/api/v1/chat/brief", json=payload, headers=auth_header("pro"))
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_rest_brief_project_missing_uuid_returns_422(client):
    """A project-mode brief sent without any project_id is rejected with 422."""
    payload = {"mode": "project"}
    response = client.post("/api/v1/chat/brief", json=payload, headers=auth_header("pro"))
    assert response.status_code == 422