506 lines
18 KiB
Python
506 lines
18 KiB
Python
"""Unit tests for single-agent deep_agent flows with mocked tool results."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import date, timedelta
|
|
from types import SimpleNamespace
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from langchain_core.messages import AIMessage, ToolMessage
|
|
|
|
from app.core.deep_agent import (
|
|
_build_system_prompt,
|
|
_datetime_context_injection,
|
|
_infer_floating_domain,
|
|
_normalize_tagged_list_lines,
|
|
_request_context_block,
|
|
run_floating,
|
|
run_floating_stream,
|
|
run_home,
|
|
)
|
|
|
|
|
|
class _FakeTool:
|
|
name = "list_tasks"
|
|
|
|
async def ainvoke(self, args):
|
|
return {"rows": [{"id": "task-1", "title": "Mock Task"}], "echo": args}
|
|
|
|
|
|
class _FakeLLM:
|
|
def __init__(self) -> None:
|
|
self.agent_calls = 0
|
|
|
|
def bind_tools(self, _tools):
|
|
return self
|
|
|
|
async def ainvoke(self, messages):
|
|
system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
|
|
if "strict domain classifier" in system_prompt:
|
|
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
|
|
|
|
self.agent_calls += 1
|
|
if self.agent_calls == 1:
|
|
return AIMessage(
|
|
content="",
|
|
tool_calls=[
|
|
{
|
|
"id": "call-1",
|
|
"name": "list_tasks",
|
|
"args": {"project_id": "proj-1"},
|
|
}
|
|
],
|
|
)
|
|
|
|
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
|
assert tool_messages, "Expected at least one tool message"
|
|
return AIMessage(content=f"Final answer from mocked tool: {tool_messages[-1].content}")
|
|
|
|
async def astream(self, _messages):
|
|
yield SimpleNamespace(content="stream-")
|
|
yield SimpleNamespace(content="ok")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_home_uses_mocked_tool_result():
    """run_home must surface the mocked tool output through the final LLM answer."""
    llm = _FakeLLM()
    llm_patch = patch("app.core.deep_agent.get_agent_llm", return_value=llm)
    tools_patch = patch("app.core.deep_agent._all_tools", return_value=[_FakeTool()])

    with llm_patch, tools_patch:
        answer = await run_home("user-1", "list my tasks", {})

    for expected in ("Final answer from mocked tool", "Mock Task"):
        assert expected in answer
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
    """The floating stream opens with the classified domain, then streams the tool-backed answer."""
    llm = _FakeLLM()
    llm_patch = patch("app.core.deep_agent.get_agent_llm", return_value=llm)
    tools_patch = patch("app.core.deep_agent._all_tools", return_value=[_FakeTool()])

    with llm_patch, tools_patch:
        events = [
            event
            async for event in run_floating_stream(
                "user-1",
                "show me timeline updates",
                {"scope": {"type": "timeline", "id": "tl-1"}},
            )
        ]

    # The classifier answer scripted in _FakeLLM must be the first event.
    expected_domain = ("floating_domain", {"type": "timeline", "id": "tl-1", "section": None})
    assert events[0] == expected_domain

    # _run_single_agent_stream uses ainvoke (not astream); the final token is
    # the second LLM response which echoes the tool result.
    tokens = [ev for ev in events if ev[0] == "token"]
    assert tokens, "Expected at least one token event"
    streamed_text = "".join(str(ev[1]) for ev in tokens)
    assert "Mock Task" in streamed_text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
    """The classifier's message-based answer wins over the request's scope type."""
    project_uuid = "213213-312321-312312-421321"

    class _ClassifierOnlyLLM:
        # Always answers with a project/task domain, whatever the prompt.
        async def ainvoke(self, _messages):
            return AIMessage(
                content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
            )

    request_ctx = {
        "scope": {"type": "timeline"},
        "resolved_project_id": project_uuid,
    }
    with patch("app.core.deep_agent.get_agent_llm", return_value=_ClassifierOnlyLLM()):
        domain = await _infer_floating_domain(
            "Quali sono i miei task per il progetto X",
            request_ctx,
        )

    assert domain == {"type": "project", "id": project_uuid, "section": "task"}
|
|
|
|
|
|
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
    """Numbered task lines that mix prose with tags collapse to the tags alone."""
    raw = (
        "Certo!\n\n"
        "1. **Task A** — priorita high <task>[task-1]</task>\n"
        "2. **Task B** — priorita medium <task>[task-2]</task>\n"
    )

    normalized = _normalize_tagged_list_lines(raw, "quali sono le prossime attivita?")

    for kept in ("<task>[task-1]</task>", "<task>[task-2]</task>"):
        assert kept in normalized
    for dropped in ("Task A", "Task B"):
        assert dropped not in normalized
|
|
|
|
|
|
def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_month_future_only():
    """An "upcoming events" query keeps future dates in the current month and drops the rest.

    Fixtures are relative to ``date.today()``:
    - yesterday is in the past and must be dropped;
    - tomorrow must be kept;
    - the 1st of next month is in the future but outside the current month,
      so it must be dropped.
    """
    today = date.today()
    tomorrow = today + timedelta(days=1)
    if tomorrow.month != today.month:
        # On the last day of a month "tomorrow" already belongs to next month, so
        # expecting tl-next to be kept would contradict the current-month filter
        # this test pins. Skip instead of failing spuriously once a month.
        pytest.skip("last day of the month: no future day left in the current month")
    yesterday = today - timedelta(days=1)
    # Day 28 exists in every month, and +5 days always lands in the following month.
    next_month = (today.replace(day=28) + timedelta(days=5)).replace(day=1)

    raw = "\n".join(
        [
            f"- Milestone old — {yesterday.strftime('%d/%m/%Y')} <timeline>[tl-old]</timeline>",
            f"- Milestone next — {tomorrow.strftime('%d/%m/%Y')} <timeline>[tl-next]</timeline>",
            f"- Milestone future — {next_month.strftime('%d/%m/%Y')} <timeline>[tl-future]</timeline>",
        ]
    )

    out = _normalize_tagged_list_lines(raw, "invece i miei eventi prossimi?")

    assert "<timeline>[tl-next]</timeline>" in out
    assert "<timeline>[tl-old]</timeline>" not in out
    assert "<timeline>[tl-future]</timeline>" not in out
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_floating_strips_xml_like_tags_from_final_text():
    """run_floating must strip ``<task>[uuid]</task>`` markup from the agent's reply."""
    # NOTE(review): "\\n" is a literal backslash-n, not a newline — confirm that is intended.
    tagged_reply = (
        "Hai 1 task:\\n"
        "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
    )

    async def _fake_run_single_agent(**_kwargs):
        return tagged_reply

    llm_patch = patch("app.core.deep_agent.get_agent_llm", return_value=_FakeLLM())
    agent_patch = patch(
        "app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
    )
    with llm_patch, agent_patch:
        text, _domain = await run_floating(
            "user-1",
            "quali task ho?",
            {"scope": {"type": "task"}},
        )

    for leftover in ("<task>", "</task>", "[180faff3-507d-4d88-aba8-66f204eb59ef]"):
        assert leftover not in text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
    """Tokens forwarded by run_floating_stream must not carry ``<task>[uuid]</task>`` markup."""
    # NOTE(review): "\\n" is a literal backslash-n, not a newline — confirm that is intended.
    scripted_tokens = (
        "Hai 1 task:\\n",
        "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>",
    )

    async def _fake_stream(**_kwargs):
        for piece in scripted_tokens:
            yield "token", piece

    llm_patch = patch("app.core.deep_agent.get_agent_llm", return_value=_FakeLLM())
    stream_patch = patch(
        "app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
    )
    with llm_patch, stream_patch:
        events = [
            event
            async for event in run_floating_stream(
                "user-1",
                "quali task ho?",
                {"scope": {"type": "task"}},
            )
        ]

    combined = "".join(str(data) for event_type, data in events if event_type == "token")
    for leftover in ("<task>", "</task>", "[180faff3-507d-4d88-aba8-66f204eb59ef]"):
        assert leftover not in combined
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
    """The final agent answer is emitted as a token even though ``astream`` yields nothing.

    The fake LLM first requests a ``list_tasks`` tool call, then answers
    "No notes found.". Its ``astream`` is an empty async generator. The stream
    must still open with ``floating_domain`` and contain the answer as a token.
    """

    class _NoChunkLLM:
        def __init__(self) -> None:
            # Counts agent calls only. Domain-classifier calls are answered separately
            # (same convention as _FakeLLM) so they cannot consume the scripted
            # tool-call turn and skip the tool path this test means to exercise.
            self.calls = 0

        def bind_tools(self, _tools):
            return self

        async def ainvoke(self, messages):
            system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
            if "strict domain classifier" in system_prompt:
                # Empty content, as the classifier received before this fix, so
                # domain resolution follows whatever path it took previously.
                return AIMessage(content="")

            self.calls += 1
            if self.calls == 1:
                return AIMessage(
                    content="",
                    tool_calls=[
                        {
                            "id": "call-1",
                            "name": "list_tasks",
                            "args": {},
                        }
                    ],
                )
            return AIMessage(content="No notes found.")

        async def astream(self, _messages):
            # The unreachable yield makes this an async generator that produces nothing.
            if False:
                yield None

    with patch("app.core.deep_agent.get_agent_llm", return_value=_NoChunkLLM()), patch(
        "app.core.deep_agent._all_tools", return_value=[_FakeTool()]
    ):
        events = []
        async for event in run_floating_stream(
            "user-1",
            "quali sono le note?",
            {"scope": {"type": "note"}},
        ):
            events.append(event)

    assert events[0][0] == "floating_domain"
    assert ("token", "No notes found.") in events
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
    """If stripping tags leaves nothing, run_floating answers with the fixed fallback text."""
    tag_only_reply = "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"

    async def _fake_run_single_agent(**_kwargs):
        return tag_only_reply

    llm_patch = patch("app.core.deep_agent.get_agent_llm", return_value=_FakeLLM())
    agent_patch = patch(
        "app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
    )
    with llm_patch, agent_patch:
        text, _domain = await run_floating(
            "user-1",
            "quali task ho?",
            {"scope": {"type": "task"}},
        )

    assert text == "No results found."
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
    """If every streamed token is stripped away, the stream must carry the fallback token."""
    tag_only_token = "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"

    async def _fake_stream(**_kwargs):
        yield "token", tag_only_token

    llm_patch = patch("app.core.deep_agent.get_agent_llm", return_value=_FakeLLM())
    stream_patch = patch(
        "app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
    )
    with llm_patch, stream_patch:
        events = [
            event
            async for event in run_floating_stream(
                "user-1",
                "quali task ho?",
                {"scope": {"type": "task"}},
            )
        ]

    assert ("token", "No results found.") in events
|
|
|
|
|
|
# ── _datetime_context_injection ────────────────────────────────────────────────
|
|
|
|
def _fp(tz: str, now_iso: str) -> dict:
|
|
return {"timezone": tz, "now_iso": now_iso, "date_format": "dd/MM/yyyy", "time_format": "24h"}
|
|
|
|
|
|
def _parse_ms(block: str, key: str) -> tuple[int, int]:
|
|
"""Extract [start, end] from a 'key [start, end]' line in the DATE CONTEXT block."""
|
|
import re
|
|
m = re.search(rf"^{key}\s+\[(\d+),\s*(\d+)\]", block, re.MULTILINE)
|
|
assert m, f"Key '{key}' not found in block:\n{block}"
|
|
return int(m.group(1)), int(m.group(2))
|
|
|
|
|
|
def test_datetime_context_injection_europe_rome_late_evening():
    """22:16 CEST on 2026-04-26 — 'tomorrow' must be 2026-04-27 00:00→23:59:59.999 CEST.

    ``now_iso`` is 20:16Z, which is 22:16 local time in Rome. The local calendar
    day is therefore still the 26th, even though UTC-based and local day
    boundaries differ.
    """
    from datetime import datetime
    from zoneinfo import ZoneInfo

    block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-04-26T20:16:02.155Z")})
    assert "DATE CONTEXT" in block
    assert "Europe/Rome" in block

    tz = ZoneInfo("Europe/Rome")
    # Each window runs from local midnight to the millisecond before the next local midnight.
    today_start = int(datetime(2026, 4, 26, tzinfo=tz).timestamp() * 1000)
    today_end = int(datetime(2026, 4, 27, tzinfo=tz).timestamp() * 1000) - 1
    tomorrow_start = today_end + 1
    tomorrow_end = int(datetime(2026, 4, 28, tzinfo=tz).timestamp() * 1000) - 1

    t_s, t_e = _parse_ms(block, "today")
    assert t_s == today_start
    assert t_e == today_end

    tm_s, tm_e = _parse_ms(block, "tomorrow")
    assert tm_s == tomorrow_start
    assert tm_e == tomorrow_end

    # Sanity: window is exactly 86 400 000 ms (1 day, CEST has no DST jump on this date)
    assert today_end - today_start + 1 == 86_400_000
    assert tomorrow_end - tomorrow_start + 1 == 86_400_000
|
|
|
|
|
|
def test_datetime_context_injection_utc():
    """UTC timezone: boundaries are clean UTC midnights."""
    from datetime import datetime, timezone

    one_day_ms = 86_400_000
    block = _datetime_context_injection({"format_prefs": _fp("UTC", "2026-01-15T10:00:00Z")})
    midnight_ms = int(datetime(2026, 1, 15, tzinfo=timezone.utc).timestamp() * 1000)

    start, end = _parse_ms(block, "today")
    assert start == midnight_ms
    assert end == midnight_ms + one_day_ms - 1
|
|
|
|
|
|
def test_datetime_context_injection_dst_spring_forward():
    """Europe/Rome DST spring-forward 2026-03-29: that day is 23h, not 24h."""
    from datetime import datetime
    from zoneinfo import ZoneInfo

    rome = ZoneInfo("Europe/Rome")

    def _local_midnight_ms(day: int) -> int:
        # Milliseconds since the epoch at local midnight of March *day*, 2026.
        return int(datetime(2026, 3, day, tzinfo=rome).timestamp() * 1000)

    block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-03-29T08:00:00Z")})

    start, end = _parse_ms(block, "today")
    assert start == _local_midnight_ms(29)
    assert end == _local_midnight_ms(30) - 1
    assert end - start + 1 == 23 * 3_600_000  # 23-hour day
|
|
|
|
|
|
def test_datetime_context_injection_dst_fall_back():
    """Europe/Rome DST fall-back 2026-10-25: that day is 25h."""
    from datetime import datetime
    from zoneinfo import ZoneInfo

    rome = ZoneInfo("Europe/Rome")

    def _local_midnight_ms(day: int) -> int:
        # Milliseconds since the epoch at local midnight of October *day*, 2026.
        return int(datetime(2026, 10, day, tzinfo=rome).timestamp() * 1000)

    block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-10-25T08:00:00Z")})

    start, end = _parse_ms(block, "today")
    assert start == _local_midnight_ms(25)
    assert end == _local_midnight_ms(26) - 1
    assert end - start + 1 == 25 * 3_600_000  # 25-hour day
|
|
|
|
|
|
def test_datetime_context_injection_year_boundary():
    """Dec 31 → Jan 1: last_year, this_year, next_month cross year boundary correctly."""
    from datetime import datetime
    from zoneinfo import ZoneInfo

    utc = ZoneInfo("UTC")

    def _jan_first_ms(year: int) -> int:
        # Milliseconds since the epoch at 00:00 UTC on January 1st of *year*.
        return int(datetime(year, 1, 1, tzinfo=utc).timestamp() * 1000)

    block = _datetime_context_injection({"format_prefs": _fp("UTC", "2026-12-31T23:00:00Z")})

    this_start, this_end = _parse_ms(block, "this_year")
    assert this_start == _jan_first_ms(2026)
    assert this_end == _jan_first_ms(2027) - 1

    last_start, last_end = _parse_ms(block, "last_year")
    assert last_start == _jan_first_ms(2025)
    assert last_end == this_start - 1

    next_month_start, _ = _parse_ms(block, "next_month")
    assert next_month_start == _jan_first_ms(2027)
|
|
|
|
|
|
def test_datetime_context_injection_missing_format_prefs():
    """Absent, None, or non-dict format_prefs all render no date context."""
    for ctx in ({}, {"format_prefs": None}, {"format_prefs": "bad"}):
        assert _datetime_context_injection(ctx) == ""
|
|
|
|
|
|
# ── _request_context_block ─────────────────────────────────────────────────────
|
|
|
|
def test_request_context_block_scope_and_project():
    """Scope and resolved project id both appear in the request-context block."""
    rendered = _request_context_block(
        {"scope": {"type": "task", "id": "t-1"}, "resolved_project_id": "proj-uuid"}
    )
    for fragment in ("scope", "resolved_project_id: proj-uuid"):
        assert fragment in rendered
|
|
|
|
|
|
def test_request_context_block_empty():
    """With no scope (absent or None) and no project id, the block is empty."""
    for ctx in ({}, {"scope": None}):
        assert _request_context_block(ctx) == ""
|
|
|
|
|
|
# ── _build_system_prompt ───────────────────────────────────────────────────────
|
|
|
|
def test_build_system_prompt_substitutes_all_slots(monkeypatch):
    """All five slots must appear in the compiled output; no raw placeholder remains."""
    import app.core.deep_agent as da

    # Return the fallback template with no prompt object, forcing the fallback
    # ``.format()`` rendering path.
    monkeypatch.setattr(da, "get_prompt_or_fallback", lambda name, fallback: (fallback, None))

    ctx = {
        "format_prefs": _fp("Europe/Rome", "2026-04-26T20:16:02.155Z"),
        "core_memory": {"language": "it"},
        "relational_memory": ["Alice — client"],
        "proactive_hints": ["User prefers morning meetings"],
        "scope": {"type": "task"},
        "resolved_project_id": "proj-1",
    }
    text, _ = _build_system_prompt("home_system", da._HOME_SYSTEM_PROMPT, ctx)

    # Every placeholder has been substituted.
    for placeholder in (
        "{date_context}",
        "{language_instruction}",
        "{relational_memory}",
        "{proactive_hints}",
        "{request_context}",
    ):
        assert placeholder not in text

    # Content from each context source was injected.
    for injected in ("DATE CONTEXT", "Italian", "Alice", "morning meetings", "proj-1"):
        assert injected in text
|
|
|
|
|
|
def test_build_system_prompt_empty_format_prefs(monkeypatch):
    """Missing format_prefs must not raise — date_context slot renders empty string."""
    import app.core.deep_agent as da

    monkeypatch.setattr(da, "get_prompt_or_fallback", lambda name, fallback: (fallback, None))

    rendered, _ = _build_system_prompt("home_system", da._HOME_SYSTEM_PROMPT, {})

    # The section header survives and the date placeholder has been consumed.
    assert "# Date filtering" in rendered
    assert "{date_context}" not in rendered
|
|
|
|
|
|
def test_human_message_is_bare_message(monkeypatch):
    """After the refactor HumanMessage content must equal the raw user message exactly.

    No request context may be prepended to the user's text (no "Context:" preamble).
    """
    import asyncio

    import app.core.deep_agent as da
    from langchain_core.messages import HumanMessage as LCHumanMessage

    # One entry per LLM call: the message list the LLM was invoked with.
    captured: list[list] = []

    class _CaptureLLM:
        def bind_tools(self, _):
            return self

        async def ainvoke(self, messages):
            captured.append(list(messages))
            return AIMessage(content="risposta")

    monkeypatch.setattr(da, "get_prompt_or_fallback", lambda n, f: (f, None))
    monkeypatch.setattr(da, "get_agent_llm", lambda _: _CaptureLLM())
    monkeypatch.setattr(da, "_all_tools_for_user", lambda *_: [])
    monkeypatch.setattr(da, "get_langfuse", lambda: None)
    monkeypatch.setattr(da, "set_tool_result_collector", lambda _: None)
    monkeypatch.setattr(da, "clear_tool_result_collector", lambda: None)

    async def _drain_stream() -> None:
        ctx = {"format_prefs": _fp("UTC", "2026-04-27T10:00:00Z")}
        async for _event in da.run_home_stream("u1", "Cosa devo fare domani?", ctx):
            pass  # only the captured LLM input matters; events are discarded

    # asyncio.run owns a fresh loop. asyncio.get_event_loop() with no current loop
    # is deprecated since Python 3.12 and raises RuntimeError in 3.14.
    asyncio.run(_drain_stream())

    assert captured, "LLM was never called"
    messages = captured[0]
    human = next(m for m in messages if isinstance(m, LCHumanMessage))
    assert human.content == "Cosa devo fare domani?"
    assert "Context:" not in human.content
|