Phase 3 — WS frame + REST fallback

This commit is contained in:
Roberto Musso
2026-04-18 22:18:53 +02:00
parent 0b5ef48463
commit d5fea95561
20 changed files with 613 additions and 15 deletions

View File

@@ -382,7 +382,6 @@ async def test_eval_runner(runner_case, pytestconfig):
await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args
inserts = [c for c in calls if c["action"] == "insert"]
score, comment = _evaluate_case(case, calls, kwargs)
if obs is not None:

163
tests/test_brief_agent.py Normal file
View File

@@ -0,0 +1,163 @@
"""Tests for Phase 3: brief agent WS frame + REST fallback.
Coverage:
- run_home_brief streams non-empty text (mocked _run_single_agent_stream)
- run_project_brief streams non-empty text for a valid project UUID (mocked _run_single_agent_stream)
- _build_read_tools uses read-only subset only (no mutating tools)
- POST /chat/brief home mode returns {response: "..."}
- POST /chat/brief project mode with invalid or missing project_id → 422
"""
from __future__ import annotations
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from tests.conftest import TEST_USER_IDS, auth_header
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_USER_ID = TEST_USER_IDS["pro"]
_EMPTY_CONTEXT: dict[str, Any] = {"core_memory": {}}
async def _fake_token_stream(*_args, **_kwargs) -> AsyncGenerator[tuple[str, Any], None]:
    """Stand-in for _run_single_agent_stream that emits two "token" events.

    Any positional or keyword arguments are accepted and ignored, so the
    function can be used as a ``side_effect`` regardless of the call shape.
    """
    for fragment in ("Hello", " world"):
        yield ("token", fragment)
# ---------------------------------------------------------------------------
# Unit: run_home_brief streams non-empty text
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_home_brief_streams_text():
    """Token events from run_home_brief join into the fake stream's text.

    The underlying agent stream is patched with ``_fake_token_stream`` so no
    real agent is invoked.
    """
    stream_patch = patch(
        "app.core.brief_agent._run_single_agent_stream",
        side_effect=_fake_token_stream,
    )
    with stream_patch:
        from app.core.brief_agent import run_home_brief

        # Keep only the payloads of "token" events, in arrival order.
        text_parts: list[str] = [
            str(payload)
            async for kind, payload in run_home_brief(_USER_ID, _EMPTY_CONTEXT)
            if kind == "token"
        ]

    assert "".join(text_parts) == "Hello world"
# ---------------------------------------------------------------------------
# Unit: run_project_brief streams text with valid UUID
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_project_brief_streams_text():
    """Token events from run_project_brief join into the fake stream's text.

    A freshly generated UUID string is used as the project id, and the
    underlying agent stream is patched with ``_fake_token_stream``.
    """
    brief_target = str(uuid.uuid4())
    collected: list[str] = []

    with patch(
        "app.core.brief_agent._run_single_agent_stream",
        side_effect=_fake_token_stream,
    ):
        from app.core.brief_agent import run_project_brief

        events = run_project_brief(_USER_ID, brief_target, _EMPTY_CONTEXT)
        async for kind, payload in events:
            # Only "token" events carry text for the assertion below.
            if kind != "token":
                continue
            collected.append(str(payload))

    assert "".join(collected) == "Hello world"
# ---------------------------------------------------------------------------
# Unit: _build_read_tools uses read-only subset (no write tools)
# ---------------------------------------------------------------------------
def test_build_read_tools_read_only_subset():
    """_build_read_tools exposes every read tool and none of the mutating ones.

    Tool identity is compared by name: the ``name`` attribute when truthy,
    otherwise ``__name__``, otherwise the tool's ``str()`` form.
    """
    from app.agents.note_agent import NOTE_READ_TOOLS
    from app.agents.project_agent import PROJECT_READ_TOOLS
    from app.agents.task_agent import TASK_READ_TOOLS
    from app.agents.timeline_agent import TIMELINE_READ_TOOLS
    from app.core.brief_agent import _build_read_tools

    def _name_of(tool):
        # Same lookup chain for both the built tools and the exported lists.
        return getattr(tool, "name", None) or getattr(tool, "__name__", str(tool))

    exposed = {_name_of(tool) for tool in _build_read_tools(_USER_ID, None)}

    # Read-only exports must be present.
    read_groups = (TASK_READ_TOOLS, PROJECT_READ_TOOLS, TIMELINE_READ_TOOLS, NOTE_READ_TOOLS)
    for group in read_groups:
        for tool in group:
            name = _name_of(tool)
            assert name in exposed, f"Read tool {name!r} missing from _build_read_tools"

    # No mutating tools (e.g. create_task, update_task, delete_task).
    forbidden = {
        "create_task", "update_task", "delete_task",
        "create_project", "update_project", "delete_project",
        "create_note", "update_note", "delete_note",
        "memory_add", "memory_update", "memory_delete",
    }
    overlap = exposed.intersection(forbidden)
    assert not overlap, f"Mutating tools in brief read-only subset: {overlap}"
# ---------------------------------------------------------------------------
# Integration: POST /chat/brief — home mode
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's ``get_session`` dependency at the test ``db_session``.

    The override is installed before the test body runs and removed again
    afterwards; ``autouse=True`` applies it without tests requesting it.
    """
    from app.db import get_session
    from app.main import app

    overrides = app.dependency_overrides

    async def _session_override():
        yield db_session

    overrides[get_session] = _session_override
    yield
    # Tolerate the key already being gone when tearing down.
    overrides.pop(get_session, None)
@pytest.mark.asyncio
async def test_rest_brief_home_returns_response(client):
    """POST /chat/brief in home mode answers 200 with the brief under "response".

    Both the brief generator and the memory-context enrichment used by the
    chat route are patched, so the response text comes from the fake stream.
    """

    async def _fake_home_brief(user_id, context):
        yield ("token", "Today looks light.")

    brief_patch = patch("app.api.routes.chat.run_home_brief", side_effect=_fake_home_brief)
    context_patch = patch(
        "app.api.routes.chat.MemoryMiddleware.enrich_context",
        new=AsyncMock(return_value={}),
    )

    with brief_patch, context_patch:
        res = client.post(
            "/api/v1/chat/brief",
            json={"mode": "home"},
            headers=auth_header("pro"),
        )

    assert res.status_code == 200
    assert res.json()["response"] == "Today looks light."
@pytest.mark.asyncio
async def test_rest_brief_project_invalid_uuid_returns_422(client):
    """POST /chat/brief in project mode with a non-UUID project_id answers 422."""
    payload = {"mode": "project", "project_id": "not-a-uuid"}
    response = client.post(
        "/api/v1/chat/brief",
        json=payload,
        headers=auth_header("pro"),
    )
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_rest_brief_project_missing_uuid_returns_422(client):
    """POST /chat/brief in project mode without a project_id answers 422."""
    payload = {"mode": "project"}
    response = client.post(
        "/api/v1/chat/brief",
        json=payload,
        headers=auth_header("pro"),
    )
    assert response.status_code == 422

View File

@@ -201,7 +201,6 @@ def test_ws_device_invalid_first_frame_closes(client):
def test_ws_device_tool_result_dispatched(client):
"""tool_result frame is routed to the DeviceConnectionManager."""
token = make_jwt(tier="free")
user_id = TEST_USER_IDS["free"]
from app.core.device_manager import device_manager as dm

View File

@@ -328,7 +328,7 @@ def _make_gmail_message(
class TestGmailClientFetchMessages:
"""GmailClient.fetch_messages tests with mocked Google API."""
def _make_client(self) -> "GmailClient":
def _make_client(self):
from app.integrations.gmail import GmailClient
return GmailClient(_TOKEN_DICT)
@@ -509,7 +509,7 @@ def _make_graph_teams_message(
class TestMSGraphClientFetchEmails:
"""MSGraphClient.fetch_emails tests with mocked httpx."""
def _make_client(self) -> "MSGraphClient":
def _make_client(self):
from app.integrations.ms_graph import MSGraphClient
return MSGraphClient(_MS_TOKEN_DICT)
@@ -608,7 +608,7 @@ class TestMSGraphClientFetchEmails:
class TestMSGraphClientFetchMessages:
"""MSGraphClient.fetch_messages (Teams) tests."""
def _make_client(self) -> "MSGraphClient":
def _make_client(self):
from app.integrations.ms_graph import MSGraphClient
return MSGraphClient(_MS_TOKEN_DICT)