refactor floating_domain to structured object-only payload

This commit is contained in:
2026-03-13 16:09:24 +01:00
parent 13fd8677c1
commit 2a0331d7ce
7 changed files with 248 additions and 49 deletions

View File

@@ -9,7 +9,7 @@ from unittest.mock import patch
import pytest
from langchain_core.messages import AIMessage, ToolMessage
from app.core.deep_agent import _normalize_tagged_list_lines, run_floating_stream, run_home
from app.core.deep_agent import _infer_floating_domain, _normalize_tagged_list_lines, run_floating_stream, run_home
class _FakeTool:
@@ -21,14 +21,18 @@ class _FakeTool:
class _FakeLLM:
def __init__(self) -> None:
self.calls = 0
self.agent_calls = 0
def bind_tools(self, _tools):
return self
async def ainvoke(self, messages):
self.calls += 1
if self.calls == 1:
system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
if "strict domain classifier" in system_prompt:
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
self.agent_calls += 1
if self.agent_calls == 1:
return AIMessage(
content="",
tool_calls=[
@@ -77,11 +81,38 @@ async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_res
):
events.append(event)
assert events[0] == ("floating_domain", "timelines")
assert events[0] == (
"floating_domain",
{"type": "timeline", "id": "tl-1", "section": None},
)
assert ("token", "stream-") in events
assert ("token", "ok") in events
@pytest.mark.asyncio
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
class _ClassifierOnlyLLM:
async def ainvoke(self, _messages):
return AIMessage(
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
)
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
domain = await _infer_floating_domain(
"Quali sono i miei task per il progetto X",
{
"scope": {"type": "timeline"},
"resolved_project_id": "213213-312321-312312-421321",
},
)
assert domain == {
"type": "project",
"id": "213213-312321-312312-421321",
"section": "task",
}
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
raw = (
"Certo!\n\n"

View File

@@ -41,11 +41,18 @@ async def test_stream_formatter_floating_domain_first() -> None:
formatter = StreamFormatter(request_id="req-2")
frames = await _collect(
formatter,
_stream(("floating_domain", "notes"), ("token", "Summary")),
_stream(
(
"floating_domain",
{"type": "node", "id": "n-1", "section": None},
),
("token", "Summary"),
),
)
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "notes"
assert frames[0].domain.type == "node"
assert frames[0].domain.id == "n-1"
assert isinstance(frames[1], WsStreamStart)
assert isinstance(frames[2], WsStreamText)
assert frames[2].chunk == "Summary"

View File

@@ -4,6 +4,7 @@ import pytest
from pydantic import ValidationError
from app.schemas import (
WsDomain,
WsFrameType,
WsHomeRequest,
WsFloatingDomain,
@@ -195,28 +196,47 @@ def test_stream_end_deserializes():
def test_floating_domain_tasks():
frame = WsFloatingDomain(request_id="r1", domain="tasks")
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
assert frame.type == WsFrameType.floating_domain
assert frame.domain == "tasks"
assert frame.domain.type == "task"
@pytest.mark.parametrize("domain", ["tasks", "timelines", "notes", "projects"])
def test_floating_domain_valid_domains(domain: str):
frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type]
assert frame.domain == domain
def test_floating_domain_valid_domains():
frame = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
)
assert frame.domain.type == "project"
assert frame.domain.id == "213213-312321-312312-421321"
assert frame.domain.section == "task"
def test_floating_domain_invalid():
with pytest.raises(ValidationError):
WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type]
def test_floating_domain_object_valid():
frame = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="project", id="p1", section="task"),
)
assert frame.domain.type == "project"
def test_floating_domain_serializes():
d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}
d = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="timeline"),
).model_dump()
assert d == {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "timeline", "id": None, "section": None},
}
def test_floating_domain_deserializes():
raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
raw = {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "node", "id": "n-1", "section": None},
}
frame = WsFloatingDomain.model_validate(raw)
assert frame.domain == "projects"
assert frame.domain.type == "node"
assert frame.domain.id == "n-1"

View File

@@ -50,7 +50,7 @@ async def _mock_home_stream(user_id, message, context):
async def _mock_floating_stream(user_id, message, context):
yield "floating_domain", "tasks"
yield "floating_domain", {"type": "task", "id": None, "section": None}
yield "token", "Here is a summary"
@@ -102,7 +102,7 @@ def test_floating_request_produces_domain_frame(client):
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
assert domain_frame["domain"] == "tasks"
assert domain_frame["domain"]["type"] == "task"
assert domain_frame["request_id"] == "p1"