Phase 3 — WS frame + REST fallbacka
This commit is contained in:
163
tests/test_brief_agent.py
Normal file
163
tests/test_brief_agent.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Tests for Phase 3: brief agent WS frame + REST fallback.
|
||||
|
||||
Coverage:
|
||||
- run_home_brief streams non-empty text (mocked _run_single_agent_stream)
|
||||
- run_project_brief with bogus UUID → WS returns stream_end with error, no crash
|
||||
- _build_read_tools uses read-only subset only (no mutating tools)
|
||||
- POST /chat/brief home mode returns {response: "..."}
|
||||
- POST /chat/brief project mode with invalid UUID → 422
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.conftest import TEST_USER_IDS, auth_header
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_USER_ID = TEST_USER_IDS["pro"]
|
||||
_EMPTY_CONTEXT: dict[str, Any] = {"core_memory": {}}
|
||||
|
||||
|
||||
async def _fake_token_stream(*_args, **_kwargs) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Fake _run_single_agent_stream that yields two token events."""
|
||||
yield ("token", "Hello")
|
||||
yield ("token", " world")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit: run_home_brief streams non-empty text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_home_brief_streams_text():
|
||||
with patch(
|
||||
"app.core.brief_agent._run_single_agent_stream",
|
||||
side_effect=_fake_token_stream,
|
||||
):
|
||||
from app.core.brief_agent import run_home_brief
|
||||
|
||||
chunks: list[str] = []
|
||||
async for event_type, data in run_home_brief(_USER_ID, _EMPTY_CONTEXT):
|
||||
if event_type == "token":
|
||||
chunks.append(str(data))
|
||||
|
||||
assert "".join(chunks) == "Hello world"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit: run_project_brief streams text with valid UUID
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_project_brief_streams_text():
|
||||
project_id = str(uuid.uuid4())
|
||||
with patch(
|
||||
"app.core.brief_agent._run_single_agent_stream",
|
||||
side_effect=_fake_token_stream,
|
||||
):
|
||||
from app.core.brief_agent import run_project_brief
|
||||
|
||||
chunks: list[str] = []
|
||||
async for event_type, data in run_project_brief(_USER_ID, project_id, _EMPTY_CONTEXT):
|
||||
if event_type == "token":
|
||||
chunks.append(str(data))
|
||||
|
||||
assert "".join(chunks) == "Hello world"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit: _build_read_tools uses read-only subset (no write tools)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_build_read_tools_read_only_subset():
|
||||
from app.agents.note_agent import NOTE_READ_TOOLS
|
||||
from app.agents.project_agent import PROJECT_READ_TOOLS
|
||||
from app.agents.task_agent import TASK_READ_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_READ_TOOLS
|
||||
from app.core.brief_agent import _build_read_tools
|
||||
|
||||
tools = _build_read_tools(_USER_ID, None)
|
||||
tool_names = {getattr(t, "name", None) or getattr(t, "__name__", str(t)) for t in tools}
|
||||
|
||||
# Read-only exports must be present.
|
||||
for read_list in (TASK_READ_TOOLS, PROJECT_READ_TOOLS, TIMELINE_READ_TOOLS, NOTE_READ_TOOLS):
|
||||
for t in read_list:
|
||||
name = getattr(t, "name", None) or getattr(t, "__name__", str(t))
|
||||
assert name in tool_names, f"Read tool {name!r} missing from _build_read_tools"
|
||||
|
||||
# No mutating tools (e.g. create_task, update_task, delete_task).
|
||||
mutating = {"create_task", "update_task", "delete_task", "create_project",
|
||||
"update_project", "delete_project", "create_note", "update_note",
|
||||
"delete_note", "memory_add", "memory_update", "memory_delete"}
|
||||
overlap = tool_names & mutating
|
||||
assert not overlap, f"Mutating tools in brief read-only subset: {overlap}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: POST /chat/brief — home mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rest_brief_home_returns_response(client):
|
||||
async def _fake_home_brief(user_id, context):
|
||||
yield ("token", "Today looks light.")
|
||||
|
||||
with (
|
||||
patch("app.api.routes.chat.run_home_brief", side_effect=_fake_home_brief),
|
||||
patch(
|
||||
"app.api.routes.chat.MemoryMiddleware.enrich_context",
|
||||
new=AsyncMock(return_value={}),
|
||||
),
|
||||
):
|
||||
res = client.post(
|
||||
"/api/v1/chat/brief",
|
||||
json={"mode": "home"},
|
||||
headers=auth_header("pro"),
|
||||
)
|
||||
|
||||
assert res.status_code == 200
|
||||
data = res.json()
|
||||
assert data["response"] == "Today looks light."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rest_brief_project_invalid_uuid_returns_422(client):
|
||||
res = client.post(
|
||||
"/api/v1/chat/brief",
|
||||
json={"mode": "project", "project_id": "not-a-uuid"},
|
||||
headers=auth_header("pro"),
|
||||
)
|
||||
assert res.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rest_brief_project_missing_uuid_returns_422(client):
|
||||
res = client.post(
|
||||
"/api/v1/chat/brief",
|
||||
json={"mode": "project"},
|
||||
headers=auth_header("pro"),
|
||||
)
|
||||
assert res.status_code == 422
|
||||
Reference in New Issue
Block a user