289 lines
9.1 KiB
Python
289 lines
9.1 KiB
Python
"""Unit tests for single-agent deep_agent flows with mocked tool results."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import date, timedelta
|
|
from types import SimpleNamespace
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from langchain_core.messages import AIMessage, ToolMessage
|
|
|
|
from app.core.deep_agent import (
|
|
_infer_floating_domain,
|
|
_normalize_tagged_list_lines,
|
|
run_floating,
|
|
run_floating_stream,
|
|
run_home,
|
|
)
|
|
|
|
|
|
class _FakeTool:
|
|
name = "list_tasks"
|
|
|
|
async def ainvoke(self, args):
|
|
return {"rows": [{"id": "task-1", "title": "Mock Task"}], "echo": args}
|
|
|
|
|
|
class _FakeLLM:
|
|
def __init__(self) -> None:
|
|
self.agent_calls = 0
|
|
|
|
def bind_tools(self, _tools):
|
|
return self
|
|
|
|
async def ainvoke(self, messages):
|
|
system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
|
|
if "strict domain classifier" in system_prompt:
|
|
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
|
|
|
|
self.agent_calls += 1
|
|
if self.agent_calls == 1:
|
|
return AIMessage(
|
|
content="",
|
|
tool_calls=[
|
|
{
|
|
"id": "call-1",
|
|
"name": "list_tasks",
|
|
"args": {"project_id": "proj-1"},
|
|
}
|
|
],
|
|
)
|
|
|
|
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
|
assert tool_messages, "Expected at least one tool message"
|
|
return AIMessage(content=f"Final answer from mocked tool: {tool_messages[-1].content}")
|
|
|
|
async def astream(self, _messages):
|
|
yield SimpleNamespace(content="stream-")
|
|
yield SimpleNamespace(content="ok")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_home_uses_mocked_tool_result():
    """The answer from run_home must carry the canned tool payload."""
    llm_stub = _FakeLLM()
    llm_patch = patch("app.core.deep_agent.get_llm", return_value=llm_stub)
    tools_patch = patch("app.core.deep_agent._all_tools", return_value=[_FakeTool()])

    with llm_patch, tools_patch:
        answer = await run_home("user-1", "list my tasks", {})

    for expected_fragment in ("Final answer from mocked tool", "Mock Task"):
        assert expected_fragment in answer
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
    """First event is the floating domain; both streamed tokens must appear."""
    llm_stub = _FakeLLM()
    llm_patch = patch("app.core.deep_agent.get_llm", return_value=llm_stub)
    tools_patch = patch("app.core.deep_agent._all_tools", return_value=[_FakeTool()])
    context = {"scope": {"type": "timeline", "id": "tl-1"}}

    with llm_patch, tools_patch:
        collected = [
            item
            async for item in run_floating_stream(
                "user-1",
                "show me timeline updates",
                context,
            )
        ]

    expected_domain = {"type": "timeline", "id": "tl-1", "section": None}
    assert collected[0] == ("floating_domain", expected_domain)
    assert ("token", "stream-") in collected
    assert ("token", "ok") in collected
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
    """A classifier reply of type "project" is returned although scope says "timeline"."""

    class _ClassifierOnlyLLM:
        """LLM double that always answers with one fixed project/task domain."""

        async def ainvoke(self, _messages):
            reply = '{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
            return AIMessage(content=reply)

    context = {
        "scope": {"type": "timeline"},
        "resolved_project_id": "213213-312321-312312-421321",
    }
    expected_domain = {
        "type": "project",
        "id": "213213-312321-312312-421321",
        "section": "task",
    }

    with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
        inferred = await _infer_floating_domain(
            "Quali sono i miei task per il progetto X",
            context,
        )

    assert inferred == expected_domain
|
|
|
|
|
|
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
    """Task tags must survive normalization while the task titles are dropped."""
    model_reply = "".join(
        [
            "Certo!\n\n",
            "1. **Task A** — priorita high <task>[task-1]</task>\n",
            "2. **Task B** — priorita medium <task>[task-2]</task>\n",
        ]
    )

    normalized = _normalize_tagged_list_lines(model_reply, "quali sono le prossime attivita?")

    for kept_tag in ("<task>[task-1]</task>", "<task>[task-2]</task>"):
        assert kept_tag in normalized
    for dropped_title in ("Task A", "Task B"):
        assert dropped_title not in normalized
|
|
|
|
|
|
def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_month_future_only():
    """Only the entry dated tomorrow should remain; past and next-month entries go."""
    today = date.today()
    one_day = timedelta(days=1)
    past_day = today - one_day
    # NOTE(review): on the last day of a month this date falls in the next
    # month. If the function filters to the current month (as the test name
    # suggests), the "tl-next" assertion may fail that day. Confirm, and
    # consider freezing the date.
    upcoming_day = today + one_day
    # Day 28 plus five days always lands in the following month; snap to day 1.
    first_of_next_month = (today.replace(day=28) + timedelta(days=5)).replace(day=1)

    dated_entries = (
        ("old", past_day),
        ("next", upcoming_day),
        ("future", first_of_next_month),
    )
    raw = "\n".join(
        f"- Milestone {label} — {day.strftime('%d/%m/%Y')} <timeline>[tl-{label}]</timeline>"
        for label, day in dated_entries
    )

    normalized = _normalize_tagged_list_lines(raw, "invece i miei eventi prossimi?")

    assert "<timeline>[tl-next]</timeline>" in normalized
    for filtered_tag in ("<timeline>[tl-old]</timeline>", "<timeline>[tl-future]</timeline>"):
        assert filtered_tag not in normalized
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_floating_strips_xml_like_tags_from_final_text():
    """The text returned by run_floating must not contain the task tag or its id."""
    llm_stub = _FakeLLM()

    async def _canned_agent(**_ignored):
        """Return a fixed reply that embeds a tagged task id."""
        heading = "Hai 1 task:\\n"
        tagged_line = "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
        return heading + tagged_line

    llm_patch = patch("app.core.deep_agent.get_llm", return_value=llm_stub)
    agent_patch = patch("app.core.deep_agent._run_single_agent", side_effect=_canned_agent)

    with llm_patch, agent_patch:
        reply, _ = await run_floating(
            "user-1",
            "quali task ho?",
            {"scope": {"type": "task"}},
        )

    forbidden = ("<task>", "</task>", "[180faff3-507d-4d88-aba8-66f204eb59ef]")
    for fragment in forbidden:
        assert fragment not in reply
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
    """Joined token payloads must not contain the task tag or its id."""
    llm_stub = _FakeLLM()

    async def _canned_stream(**_ignored):
        """Emit two token events, the second carrying a tagged task id."""
        pieces = (
            "Hai 1 task:\\n",
            "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>",
        )
        for piece in pieces:
            yield "token", piece

    llm_patch = patch("app.core.deep_agent.get_llm", return_value=llm_stub)
    stream_patch = patch(
        "app.core.deep_agent._run_single_agent_stream", side_effect=_canned_stream
    )

    with llm_patch, stream_patch:
        collected = [
            item
            async for item in run_floating_stream(
                "user-1",
                "quali task ho?",
                {"scope": {"type": "task"}},
            )
        ]

    streamed_text = "".join(str(payload) for kind, payload in collected if kind == "token")
    forbidden = ("<task>", "</task>", "[180faff3-507d-4d88-aba8-66f204eb59ef]")
    for fragment in forbidden:
        assert fragment not in streamed_text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
    """With an empty astream, the final ainvoke content must be emitted as a token."""

    class _NoChunkLLM:
        """LLM double whose stream yields nothing; ainvoke is scripted by call count."""

        def __init__(self) -> None:
            self.calls = 0

        def bind_tools(self, _tools):
            """Ignore the supplied tools and hand back this same instance."""
            return self

        async def ainvoke(self, _messages):
            """Request ``list_tasks`` on the first call, then return plain text."""
            self.calls += 1
            if self.calls != 1:
                return AIMessage(content="No notes found.")
            requested_call = {
                "id": "call-1",
                "name": "list_tasks",
                "args": {},
            }
            return AIMessage(content="", tool_calls=[requested_call])

        async def astream(self, _messages):
            """Async generator that finishes without producing any chunk."""
            return
            yield  # unreachable; its presence makes this an async generator

    llm_patch = patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM())
    tools_patch = patch("app.core.deep_agent._all_tools", return_value=[_FakeTool()])

    with llm_patch, tools_patch:
        collected = [
            item
            async for item in run_floating_stream(
                "user-1",
                "quali sono le note?",
                {"scope": {"type": "note"}},
            )
        ]

    first_kind = collected[0][0]
    assert first_kind == "floating_domain"
    assert ("token", "No notes found.") in collected
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
    """An agent reply that is only a tagged id must come back as "No results found."."""
    llm_stub = _FakeLLM()

    async def _tag_only_agent(**_ignored):
        """Return a reply consisting solely of a tagged task id."""
        return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"

    llm_patch = patch("app.core.deep_agent.get_llm", return_value=llm_stub)
    agent_patch = patch("app.core.deep_agent._run_single_agent", side_effect=_tag_only_agent)

    with llm_patch, agent_patch:
        reply, _ = await run_floating(
            "user-1",
            "quali task ho?",
            {"scope": {"type": "task"}},
        )

    assert reply == "No results found."
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
    """A stream carrying only a tagged id must produce a "No results found." token."""
    llm_stub = _FakeLLM()

    async def _tag_only_stream(**_ignored):
        """Emit a single token event consisting solely of a tagged task id."""
        yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"

    llm_patch = patch("app.core.deep_agent.get_llm", return_value=llm_stub)
    stream_patch = patch(
        "app.core.deep_agent._run_single_agent_stream", side_effect=_tag_only_stream
    )

    with llm_patch, stream_patch:
        collected = [
            item
            async for item in run_floating_stream(
                "user-1",
                "quali task ho?",
                {"scope": {"type": "task"}},
            )
        ]

    fallback_event = ("token", "No results found.")
    assert fallback_event in collected
|