refactor floating_domain to structured object-only payload
This commit is contained in:
@@ -9,7 +9,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.core.deep_agent import _normalize_tagged_list_lines, run_floating_stream, run_home
|
||||
from app.core.deep_agent import _infer_floating_domain, _normalize_tagged_list_lines, run_floating_stream, run_home
|
||||
|
||||
|
||||
class _FakeTool:
|
||||
@@ -21,14 +21,18 @@ class _FakeTool:
|
||||
|
||||
class _FakeLLM:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
self.agent_calls = 0
|
||||
|
||||
def bind_tools(self, _tools):
|
||||
return self
|
||||
|
||||
async def ainvoke(self, messages):
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
|
||||
if "strict domain classifier" in system_prompt:
|
||||
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
|
||||
|
||||
self.agent_calls += 1
|
||||
if self.agent_calls == 1:
|
||||
return AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
@@ -77,11 +81,38 @@ async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_res
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert events[0] == ("floating_domain", "timelines")
|
||||
assert events[0] == (
|
||||
"floating_domain",
|
||||
{"type": "timeline", "id": "tl-1", "section": None},
|
||||
)
|
||||
assert ("token", "stream-") in events
|
||||
assert ("token", "ok") in events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
|
||||
class _ClassifierOnlyLLM:
|
||||
async def ainvoke(self, _messages):
|
||||
return AIMessage(
|
||||
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
|
||||
)
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
|
||||
domain = await _infer_floating_domain(
|
||||
"Quali sono i miei task per il progetto X",
|
||||
{
|
||||
"scope": {"type": "timeline"},
|
||||
"resolved_project_id": "213213-312321-312312-421321",
|
||||
},
|
||||
)
|
||||
|
||||
assert domain == {
|
||||
"type": "project",
|
||||
"id": "213213-312321-312312-421321",
|
||||
"section": "task",
|
||||
}
|
||||
|
||||
|
||||
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
|
||||
raw = (
|
||||
"Certo!\n\n"
|
||||
|
||||
Reference in New Issue
Block a user