diff --git a/tests/test_deep_agent.py b/tests/test_deep_agent.py index 231ce0d..c09b53b 100644 --- a/tests/test_deep_agent.py +++ b/tests/test_deep_agent.py @@ -12,11 +12,8 @@ from langchain_core.messages import AIMessage, ToolMessage from app.core.deep_agent import ( _build_system_prompt, _datetime_context_injection, - _infer_floating_domain, _normalize_tagged_list_lines, _request_context_block, - run_floating, - run_floating_stream, run_home, ) @@ -75,57 +72,6 @@ async def test_run_home_uses_mocked_tool_result(): assert "Mock Task" in out -@pytest.mark.asyncio -async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result(): - fake_llm = _FakeLLM() - - with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch( - "app.core.deep_agent._all_tools", return_value=[_FakeTool()] - ): - events = [] - async for event in run_floating_stream( - "user-1", - "show me timeline updates", - {"scope": {"type": "timeline", "id": "tl-1"}}, - ): - events.append(event) - - assert events[0] == ( - "floating_domain", - {"type": "timeline", "id": "tl-1", "section": None}, - ) - # _run_single_agent_stream uses ainvoke (not astream); the final token is - # the second LLM response which echoes the tool result. - token_events = [e for e in events if e[0] == "token"] - assert token_events, "Expected at least one token event" - combined = "".join(str(e[1]) for e in token_events) - assert "Mock Task" in combined - - -@pytest.mark.asyncio -async def test_infer_floating_domain_prefers_message_intent_over_scope_type(): - class _ClassifierOnlyLLM: - async def ainvoke(self, _messages): - return AIMessage( - content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}' - ) - - with patch("app.core.deep_agent.get_agent_llm", return_value=_ClassifierOnlyLLM()): - domain = await _infer_floating_domain( - "Quali sono i miei task per il progetto X", - { - "scope": {"type": "timeline"}, - "resolved_project_id": "213213-312321-312312-421321", - }, - ) - - assert domain == { - "type": "project", - "id": "213213-312321-312312-421321", - "section": "task", - } - - def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines(): raw = ( "Certo!\n\n" @@ -162,139 +108,6 @@ def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_ assert "[tl-future]" not in out -@pytest.mark.asyncio -async def test_run_floating_strips_xml_like_tags_from_final_text(): - fake_llm = _FakeLLM() - - async def _fake_run_single_agent(**_kwargs): - return ( - "Hai 1 task:\\n" - "Mail barra in prod [180faff3-507d-4d88-aba8-66f204eb59ef]" - ) - - with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch( - "app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent - ): - text, _domain = await run_floating( - "user-1", - "quali task ho?", - {"scope": {"type": "task"}}, - ) - - assert "" not in text - assert "" not in text - assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text - - -@pytest.mark.asyncio -async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text(): - fake_llm = _FakeLLM() - - async def _fake_stream(**_kwargs): - yield "token", "Hai 1 task:\\n" - yield "token", "Mail barra in prod [180faff3-507d-4d88-aba8-66f204eb59ef]" - - with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch( - "app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream - ): - events = [] - async for event in run_floating_stream( - "user-1", - "quali task ho?", - {"scope": {"type": "task"}}, - ): - events.append(event) - - token_events = [str(data) for event_type, data in events if event_type == "token"] - combined = "".join(token_events) - assert "" not in combined - assert "" not in combined - assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined - - -@pytest.mark.asyncio -async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty(): - class _NoChunkLLM: - def __init__(self) -> None: - self.calls = 0 - - def bind_tools(self, _tools): - return self - - async def ainvoke(self, _messages): - self.calls += 1 - if self.calls == 1: - return AIMessage( - content="", - tool_calls=[ - { - "id": "call-1", - "name": "list_tasks", - "args": {}, - } - ], - ) - return AIMessage(content="No notes found.") - - async def astream(self, _messages): - if False: - yield None - - with patch("app.core.deep_agent.get_agent_llm", return_value=_NoChunkLLM()), patch( - "app.core.deep_agent._all_tools", return_value=[_FakeTool()] - ): - events = [] - async for event in run_floating_stream( - "user-1", - "quali sono le note?", - {"scope": {"type": "note"}}, - ): - events.append(event) - - assert events[0][0] == "floating_domain" - assert ("token", "No notes found.") in events - - -@pytest.mark.asyncio -async def test_run_floating_returns_fallback_when_sanitization_would_empty_text(): - fake_llm = _FakeLLM() - - async def _fake_run_single_agent(**_kwargs): - return "[180faff3-507d-4d88-aba8-66f204eb59ef]" - - with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch( - "app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent - ): - text, _domain = await run_floating( - "user-1", - "quali task ho?", - {"scope": {"type": "task"}}, - ) - - assert text == "No results found." - - -@pytest.mark.asyncio -async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text(): - fake_llm = _FakeLLM() - - async def _fake_stream(**_kwargs): - yield "token", "[180faff3-507d-4d88-aba8-66f204eb59ef]" - - with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch( - "app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream - ): - events = [] - async for event in run_floating_stream( - "user-1", - "quali task ho?", - {"scope": {"type": "task"}}, - ): - events.append(event) - - assert ("token", "No results found.") in events - - # ── _datetime_context_injection ──────────────────────────────────────────────── def _fp(tz: str, now_iso: str) -> dict: diff --git a/tests/test_ws_unified.py b/tests/test_ws_unified.py index 2af4364..e1c9b1b 100644 --- a/tests/test_ws_unified.py +++ b/tests/test_ws_unified.py @@ -1,6 +1,6 @@ """Integration tests for the unified WebSocket handler (Step 5). -Tests the device WS endpoint with home_request and floating_request frames, +Tests the device WS endpoint with home_request frames, verifying that the correct v3 frame sequence is returned. LLM calls are mocked to avoid network dependency. @@ -34,7 +34,7 @@ def _override_db(db_session): def _recv_until_end(ws, max_frames: int = 20) -> list[dict]: - """Receive frames until stream_end (or stream_end inside floating flow), or max_frames.""" + """Receive frames until stream_end or max_frames.""" frames = [] for _ in range(max_frames): raw = ws.receive_text() @@ -49,11 +49,6 @@ async def _mock_home_stream(user_id, message, context): yield "token", "Hello" -async def _mock_floating_stream(user_id, message, context): - yield "floating_domain", {"type": "task", "id": None, "section": None} - yield "token", "Here is a summary" - - # ── tests ───────────────────────────────────────────────────────────────────── def test_home_request_produces_stream_frames(client): @@ -79,33 +74,6 @@ def test_home_request_produces_stream_frames(client): assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end) -def test_floating_request_produces_domain_frame(client): - """floating_request → floating_domain first, then stream_text*, stream_end.""" - token = make_jwt("power", user_id=USER_ID) - - with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream): - with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: - ws.send_text(json.dumps({ - "type": "device_hello", "device_id": "dev-2", "agent_ids": [] - })) - ws.send_text(json.dumps({ - "type": "floating_request", - "request_id": "p1", - "message": "Summarize this task", - "scope": {"type": "task", "id": "task-123"}, - })) - frames = _recv_until_end(ws) - - types = [f["type"] for f in frames] - assert WsFrameType.floating_domain in types - assert WsFrameType.stream_end in types - assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end) - - domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain) - assert domain_frame["domain"]["type"] == "task" - assert domain_frame["request_id"] == "p1" - - def test_home_request_request_id_propagated(client): """request_id in home_request is echoed in all response frames.""" token = make_jwt("power", user_id=USER_ID)