refactor floating_domain to structured object-only payload

This commit is contained in:
2026-03-13 16:09:24 +01:00
parent 13fd8677c1
commit 2a0331d7ce
7 changed files with 248 additions and 49 deletions

View File

@@ -23,7 +23,8 @@ from app.db import async_session
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"] FloatingDomainType = Literal["task", "timeline", "project", "node"]
FloatingDomainSection = Literal["task", "timeline", "note"]
_HOME_SINGLE_AGENT_SYSTEM = ( _HOME_SINGLE_AGENT_SYSTEM = (
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. " "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
@@ -44,8 +45,18 @@ _FLOATING_SINGLE_AGENT_SYSTEM = (
"Always use tools for factual data retrieval before answering. " "Always use tools for factual data retrieval before answering. "
"When the user asks to remember, forget, or update what you know about them, use memory tools. " "When the user asks to remember, forget, or update what you know about them, use memory tools. "
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. " "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
"Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, " )
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
"You are a strict domain classifier for websocket floating requests. "
"Return ONLY a JSON object with keys: type, id, section. "
"Allowed type values: task, timeline, project, node. "
"Allowed section values: task, timeline, note, or null. "
"Rules: infer from user message intent first; do not blindly trust scope.type. "
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
"If project id is unknown but context.resolved_project_id exists, use it as id. "
"If id is unknown, use null. "
"No markdown, no prose, JSON only."
) )
@@ -347,27 +358,145 @@ def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
return [*_all_tools(), *_memory_tools(user_id, trace_id)] return [*_all_tools(), *_memory_tools(user_id, trace_id)]
def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDomain: def _detect_domain_section(message: str) -> FloatingDomainSection | None:
scope = context.get("scope") if isinstance(context, dict) else None
if isinstance(scope, dict):
scope_type = str(scope.get("type") or "").strip().lower()
if scope_type in {"task", "tasks"}:
return "tasks"
if scope_type in {"project", "projects"}:
return "projects"
if scope_type in {"note", "notes"}:
return "notes"
if scope_type in {"timeline", "timelines"}:
return "timelines"
lowered = message.lower() lowered = message.lower()
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]): if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
return "timelines" return "timeline"
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
return "task"
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]): if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
return "notes" return "note"
if any(keyword in lowered for keyword in ["project", "progetto", "client"]): return None
return "projects"
return "tasks"
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
type_raw = str(payload.get("type") or "").strip().lower()
domain_type: FloatingDomainType = "task"
if type_raw in {"task", "timeline", "project", "node"}:
domain_type = type_raw
id_value = payload.get("id")
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
if domain_type == "project" and not domain_id:
domain_id = fallback_id
section_raw = payload.get("section")
section: FloatingDomainSection | None = None
if isinstance(section_raw, str):
section_candidate = section_raw.strip().lower()
if section_candidate in {"task", "timeline", "note"}:
section = section_candidate
if domain_type != "project":
section = None
return {
"type": domain_type,
"id": domain_id,
"section": section,
}
def _parse_json_object(text: str) -> dict[str, Any] | None:
raw = text.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", raw, re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
section = _detect_domain_section(message)
scope = context.get("scope") if isinstance(context, dict) else None
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
if isinstance(scope, dict):
scope_type = str(scope.get("type") or "").strip().lower()
scope_id = scope.get("id")
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
if scope_type in {"task", "tasks"}:
return {"type": "task", "id": scope_id_value, "section": None}
if scope_type in {"project", "projects"}:
project_scope_id = scope_id_value or project_id
return {
"type": "project",
"id": project_scope_id,
"section": section,
}
if scope_type in {"note", "notes"}:
return {
"type": "node",
"id": scope_id_value,
"section": None,
}
if scope_type in {"timeline", "timelines"}:
return {"type": "timeline", "id": scope_id_value, "section": None}
lowered = message.lower()
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
return {
"type": "project",
"id": project_id,
"section": section,
}
if section == "timeline":
return {"type": "timeline", "id": None, "section": None}
if section == "note":
return {"type": "node", "id": None, "section": None}
return {"type": "task", "id": None, "section": None}
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
classifier_context = {
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
"resolved_project_id": project_id,
}
try:
llm = get_llm()
response = await llm.ainvoke(
[
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_SYSTEM),
HumanMessage(
content=(
f"Message:\n{message}\n\n"
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
)
),
]
)
parsed = _parse_json_object(_as_text(response.content))
if parsed is not None:
domain = _normalize_domain_payload(parsed, project_id)
logger.info(
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
domain.get("type"),
domain.get("id"),
domain.get("section"),
)
return domain
logger.warning("deep_agent: floating_domain classifier returned non-json output")
except Exception as exc:
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
return _infer_floating_domain_rule_based(message, context)
async def _run_single_agent( async def _run_single_agent(
@@ -558,9 +687,9 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
return _normalize_tagged_list_lines(response, message) return _normalize_tagged_list_lines(response, message)
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]: async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
domain = _infer_floating_domain(message, context)
prepared_context = await _prepare_context(message, context) prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
response = await _run_single_agent( response = await _run_single_agent(
user_id=user_id, user_id=user_id,
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM, system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
@@ -599,10 +728,10 @@ async def run_floating_stream(
message: str, message: str,
context: dict[str, Any], context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]: ) -> AsyncGenerator[tuple[str, Any], None]:
domain = _infer_floating_domain(message, context) prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
yield "floating_domain", domain yield "floating_domain", domain
prepared_context = await _prepare_context(message, context)
async for event in _run_single_agent_stream( async for event in _run_single_agent_stream(
user_id=user_id, user_id=user_id,
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM, system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,

View File

@@ -24,7 +24,11 @@ class StreamFormatter:
async for event_type, data in event_stream: async for event_type, data in event_stream:
if event_type == "floating_domain": if event_type == "floating_domain":
yield WsFloatingDomain(request_id=self.request_id, domain=str(data)) if isinstance(data, dict):
yield WsFloatingDomain(
request_id=self.request_id,
domain=data,
)
continue continue
if event_type != "token": if event_type != "token":

View File

@@ -281,12 +281,20 @@ class WsStreamEnd(BaseModel):
request_id: str request_id: str
class WsDomain(BaseModel):
"""Structured floating domain payload for UI routing decisions."""
type: Literal["task", "timeline", "project", "node"]
id: str | None = None
section: Literal["task", "timeline", "note"] | None = None
class WsFloatingDomain(BaseModel): class WsFloatingDomain(BaseModel):
"""Server → Client: domain determined for a floating request.""" """Server → Client: domain determined for a floating request."""
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
request_id: str request_id: str
domain: Literal["tasks", "timelines", "notes", "projects"] domain: WsDomain
# ── Agent Catalog ───────────────────────────────────────────────────── # ── Agent Catalog ─────────────────────────────────────────────────────

View File

@@ -9,7 +9,7 @@ from unittest.mock import patch
import pytest import pytest
from langchain_core.messages import AIMessage, ToolMessage from langchain_core.messages import AIMessage, ToolMessage
from app.core.deep_agent import _normalize_tagged_list_lines, run_floating_stream, run_home from app.core.deep_agent import _infer_floating_domain, _normalize_tagged_list_lines, run_floating_stream, run_home
class _FakeTool: class _FakeTool:
@@ -21,14 +21,18 @@ class _FakeTool:
class _FakeLLM: class _FakeLLM:
def __init__(self) -> None: def __init__(self) -> None:
self.calls = 0 self.agent_calls = 0
def bind_tools(self, _tools): def bind_tools(self, _tools):
return self return self
async def ainvoke(self, messages): async def ainvoke(self, messages):
self.calls += 1 system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
if self.calls == 1: if "strict domain classifier" in system_prompt:
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
self.agent_calls += 1
if self.agent_calls == 1:
return AIMessage( return AIMessage(
content="", content="",
tool_calls=[ tool_calls=[
@@ -77,11 +81,38 @@ async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_res
): ):
events.append(event) events.append(event)
assert events[0] == ("floating_domain", "timelines") assert events[0] == (
"floating_domain",
{"type": "timeline", "id": "tl-1", "section": None},
)
assert ("token", "stream-") in events assert ("token", "stream-") in events
assert ("token", "ok") in events assert ("token", "ok") in events
@pytest.mark.asyncio
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
class _ClassifierOnlyLLM:
async def ainvoke(self, _messages):
return AIMessage(
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
)
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
domain = await _infer_floating_domain(
"Quali sono i miei task per il progetto X",
{
"scope": {"type": "timeline"},
"resolved_project_id": "213213-312321-312312-421321",
},
)
assert domain == {
"type": "project",
"id": "213213-312321-312312-421321",
"section": "task",
}
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines(): def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
raw = ( raw = (
"Certo!\n\n" "Certo!\n\n"

View File

@@ -41,11 +41,18 @@ async def test_stream_formatter_floating_domain_first() -> None:
formatter = StreamFormatter(request_id="req-2") formatter = StreamFormatter(request_id="req-2")
frames = await _collect( frames = await _collect(
formatter, formatter,
_stream(("floating_domain", "notes"), ("token", "Summary")), _stream(
(
"floating_domain",
{"type": "node", "id": "n-1", "section": None},
),
("token", "Summary"),
),
) )
assert isinstance(frames[0], WsFloatingDomain) assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "notes" assert frames[0].domain.type == "node"
assert frames[0].domain.id == "n-1"
assert isinstance(frames[1], WsStreamStart) assert isinstance(frames[1], WsStreamStart)
assert isinstance(frames[2], WsStreamText) assert isinstance(frames[2], WsStreamText)
assert frames[2].chunk == "Summary" assert frames[2].chunk == "Summary"

View File

@@ -4,6 +4,7 @@ import pytest
from pydantic import ValidationError from pydantic import ValidationError
from app.schemas import ( from app.schemas import (
WsDomain,
WsFrameType, WsFrameType,
WsHomeRequest, WsHomeRequest,
WsFloatingDomain, WsFloatingDomain,
@@ -195,28 +196,47 @@ def test_stream_end_deserializes():
def test_floating_domain_tasks(): def test_floating_domain_tasks():
frame = WsFloatingDomain(request_id="r1", domain="tasks") frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
assert frame.type == WsFrameType.floating_domain assert frame.type == WsFrameType.floating_domain
assert frame.domain == "tasks" assert frame.domain.type == "task"
@pytest.mark.parametrize("domain", ["tasks", "timelines", "notes", "projects"]) def test_floating_domain_valid_domains():
def test_floating_domain_valid_domains(domain: str): frame = WsFloatingDomain(
frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type] request_id="r1",
assert frame.domain == domain domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
)
assert frame.domain.type == "project"
assert frame.domain.id == "213213-312321-312312-421321"
assert frame.domain.section == "task"
def test_floating_domain_invalid(): def test_floating_domain_object_valid():
with pytest.raises(ValidationError): frame = WsFloatingDomain(
WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type] request_id="r1",
domain=WsDomain(type="project", id="p1", section="task"),
)
assert frame.domain.type == "project"
def test_floating_domain_serializes(): def test_floating_domain_serializes():
d = WsFloatingDomain(request_id="r1", domain="notes").model_dump() d = WsFloatingDomain(
assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"} request_id="r1",
domain=WsDomain(type="timeline"),
).model_dump()
assert d == {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "timeline", "id": None, "section": None},
}
def test_floating_domain_deserializes(): def test_floating_domain_deserializes():
raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"} raw = {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "node", "id": "n-1", "section": None},
}
frame = WsFloatingDomain.model_validate(raw) frame = WsFloatingDomain.model_validate(raw)
assert frame.domain == "projects" assert frame.domain.type == "node"
assert frame.domain.id == "n-1"

View File

@@ -50,7 +50,7 @@ async def _mock_home_stream(user_id, message, context):
async def _mock_floating_stream(user_id, message, context): async def _mock_floating_stream(user_id, message, context):
yield "floating_domain", "tasks" yield "floating_domain", {"type": "task", "id": None, "section": None}
yield "token", "Here is a summary" yield "token", "Here is a summary"
@@ -102,7 +102,7 @@ def test_floating_request_produces_domain_frame(client):
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end) assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain) domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
assert domain_frame["domain"] == "tasks" assert domain_frame["domain"]["type"] == "task"
assert domain_frame["request_id"] == "p1" assert domain_frame["request_id"] == "p1"