diff --git a/app/core/deep_agent.py b/app/core/deep_agent.py index ad34767..ac6957e 100644 --- a/app/core/deep_agent.py +++ b/app/core/deep_agent.py @@ -23,7 +23,8 @@ from app.db import async_session logger = logging.getLogger(__name__) -FloatingDomain = Literal["tasks", "projects", "notes", "timelines"] +FloatingDomainType = Literal["task", "timeline", "project", "node"] +FloatingDomainSection = Literal["task", "timeline", "note"] _HOME_SINGLE_AGENT_SYSTEM = ( "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. " @@ -44,8 +45,18 @@ _FLOATING_SINGLE_AGENT_SYSTEM = ( "Always use tools for factual data retrieval before answering. " "When the user asks to remember, forget, or update what you know about them, use memory tools. " "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. " - "Return markdown and embed inline tags when relevant: [ids], [ids], " - "[ids], [ids], {json}." +) + +_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = ( + "You are a strict domain classifier for websocket floating requests. " + "Return ONLY a JSON object with keys: type, id, section. " + "Allowed type values: task, timeline, project, node. " + "Allowed section values: task, timeline, note, or null. " + "Rules: infer from user message intent first; do not blindly trust scope.type. " + "If user asks tasks/timeline/notes for a project, set type=project and section accordingly. " + "If project id is unknown but context.resolved_project_id exists, use it as id. " + "If id is unknown, use null. " + "No markdown, no prose, JSON only." ) @@ -347,27 +358,145 @@ def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]: return [*_all_tools(), *_memory_tools(user_id, trace_id)] -def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDomain: - scope = context.get("scope") if isinstance(context, dict) else None - if isinstance(scope, dict): - scope_type = str(scope.get("type") or "").strip().lower() - if scope_type in {"task", "tasks"}: - return "tasks" - if scope_type in {"project", "projects"}: - return "projects" - if scope_type in {"note", "notes"}: - return "notes" - if scope_type in {"timeline", "timelines"}: - return "timelines" - +def _detect_domain_section(message: str) -> FloatingDomainSection | None: lowered = message.lower() if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]): - return "timelines" + return "timeline" + if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]): + return "task" if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]): - return "notes" - if any(keyword in lowered for keyword in ["project", "progetto", "client"]): - return "projects" - return "tasks" + return "note" + return None + + +def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]: + type_raw = str(payload.get("type") or "").strip().lower() + domain_type: FloatingDomainType = "task" + if type_raw in {"task", "timeline", "project", "node"}: + domain_type = type_raw + + id_value = payload.get("id") + domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None + if domain_type == "project" and not domain_id: + domain_id = fallback_id + + section_raw = payload.get("section") + section: FloatingDomainSection | None = None + if isinstance(section_raw, str): + section_candidate = section_raw.strip().lower() + if section_candidate in {"task", "timeline", "note"}: + section = section_candidate + + if domain_type != "project": + section = None + + return { + "type": domain_type, + "id": domain_id, + "section": section, + } + + +def _parse_json_object(text: str) -> dict[str, Any] | None: + raw = text.strip() + if not raw: + return None + try: + parsed = json.loads(raw) + return parsed if isinstance(parsed, dict) else None + except json.JSONDecodeError: + pass + + match = re.search(r"\{.*\}", raw, re.DOTALL) + if not match: + return None + try: + parsed = json.loads(match.group(0)) + except json.JSONDecodeError: + return None + return parsed if isinstance(parsed, dict) else None + + +def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]: + section = _detect_domain_section(message) + scope = context.get("scope") if isinstance(context, dict) else None + resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None + project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None + + if isinstance(scope, dict): + scope_type = str(scope.get("type") or "").strip().lower() + scope_id = scope.get("id") + scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None + + if scope_type in {"task", "tasks"}: + return {"type": "task", "id": scope_id_value, "section": None} + if scope_type in {"project", "projects"}: + project_scope_id = scope_id_value or project_id + return { + "type": "project", + "id": project_scope_id, + "section": section, + } + if scope_type in {"note", "notes"}: + return { + "type": "node", + "id": scope_id_value, + "section": None, + } + if scope_type in {"timeline", "timelines"}: + return {"type": "timeline", "id": scope_id_value, "section": None} + + lowered = message.lower() + if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id: + return { + "type": "project", + "id": project_id, + "section": section, + } + if section == "timeline": + return {"type": "timeline", "id": None, "section": None} + if section == "note": + return {"type": "node", "id": None, "section": None} + return {"type": "task", "id": None, "section": None} + + +async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]: + resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None + project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None + + classifier_context = { + "scope": context.get("scope") if isinstance(context.get("scope"), dict) else None, + "resolved_project_id": project_id, + } + + try: + llm = get_llm() + response = await llm.ainvoke( + [ + SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_SYSTEM), + HumanMessage( + content=( + f"Message:\n{message}\n\n" + f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}" + ) + ), + ] + ) + parsed = _parse_json_object(_as_text(response.content)) + if parsed is not None: + domain = _normalize_domain_payload(parsed, project_id) + logger.info( + "deep_agent: floating_domain_classified type=%s id=%s section=%s", + domain.get("type"), + domain.get("id"), + domain.get("section"), + ) + return domain + logger.warning("deep_agent: floating_domain classifier returned non-json output") + except Exception as exc: + logger.warning("deep_agent: floating_domain classifier failed: %s", exc) + + return _infer_floating_domain_rule_based(message, context) async def _run_single_agent( @@ -558,9 +687,9 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str: return _normalize_tagged_list_lines(response, message) -async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]: - domain = _infer_floating_domain(message, context) +async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]: prepared_context = await _prepare_context(message, context) + domain = await _infer_floating_domain(message, prepared_context) response = await _run_single_agent( user_id=user_id, system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM, @@ -599,10 +728,10 @@ async def run_floating_stream( message: str, context: dict[str, Any], ) -> AsyncGenerator[tuple[str, Any], None]: - domain = _infer_floating_domain(message, context) + prepared_context = await _prepare_context(message, context) + domain = await _infer_floating_domain(message, prepared_context) yield "floating_domain", domain - prepared_context = await _prepare_context(message, context) async for event in _run_single_agent_stream( user_id=user_id, system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM, diff --git a/app/core/output_formatter.py b/app/core/output_formatter.py index 429a2ce..3c6f6df 100644 --- a/app/core/output_formatter.py +++ b/app/core/output_formatter.py @@ -24,7 +24,11 @@ class StreamFormatter: async for event_type, data in event_stream: if event_type == "floating_domain": - yield WsFloatingDomain(request_id=self.request_id, domain=str(data)) + if isinstance(data, dict): + yield WsFloatingDomain( + request_id=self.request_id, + domain=data, + ) continue if event_type != "token": diff --git a/app/schemas.py b/app/schemas.py index 3005169..3f0d227 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -281,12 +281,20 @@ class WsStreamEnd(BaseModel): request_id: str +class WsDomain(BaseModel): + """Structured floating domain payload for UI routing decisions.""" + + type: Literal["task", "timeline", "project", "node"] + id: str | None = None + section: Literal["task", "timeline", "note"] | None = None + + class WsFloatingDomain(BaseModel): """Server → Client: domain determined for a floating request.""" type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain request_id: str - domain: Literal["tasks", "timelines", "notes", "projects"] + domain: WsDomain # ── Agent Catalog ───────────────────────────────────────────────────── diff --git a/tests/test_deep_agent.py b/tests/test_deep_agent.py index 729eedc..8069aa0 100644 --- a/tests/test_deep_agent.py +++ b/tests/test_deep_agent.py @@ -9,7 +9,7 @@ from unittest.mock import patch import pytest from langchain_core.messages import AIMessage, ToolMessage -from app.core.deep_agent import _normalize_tagged_list_lines, run_floating_stream, run_home +from app.core.deep_agent import _infer_floating_domain, _normalize_tagged_list_lines, run_floating_stream, run_home class _FakeTool: @@ -21,14 +21,18 @@ class _FakeTool: class _FakeLLM: def __init__(self) -> None: - self.calls = 0 + self.agent_calls = 0 def bind_tools(self, _tools): return self async def ainvoke(self, messages): - self.calls += 1 - if self.calls == 1: + system_prompt = str(getattr(messages[0], "content", "")) if messages else "" + if "strict domain classifier" in system_prompt: + return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}') + + self.agent_calls += 1 + if self.agent_calls == 1: return AIMessage( content="", tool_calls=[ @@ -77,11 +81,38 @@ async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_res ): events.append(event) - assert events[0] == ("floating_domain", "timelines") + assert events[0] == ( + "floating_domain", + {"type": "timeline", "id": "tl-1", "section": None}, + ) assert ("token", "stream-") in events assert ("token", "ok") in events +@pytest.mark.asyncio +async def test_infer_floating_domain_prefers_message_intent_over_scope_type(): + class _ClassifierOnlyLLM: + async def ainvoke(self, _messages): + return AIMessage( + content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}' + ) + + with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()): + domain = await _infer_floating_domain( + "Quali sono i miei task per il progetto X", + { + "scope": {"type": "timeline"}, + "resolved_project_id": "213213-312321-312312-421321", + }, + ) + + assert domain == { + "type": "project", + "id": "213213-312321-312312-421321", + "section": "task", + } + + def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines(): raw = ( "Certo!\n\n" diff --git a/tests/test_output_formatter.py b/tests/test_output_formatter.py index 2f06f79..b9b6741 100644 --- a/tests/test_output_formatter.py +++ b/tests/test_output_formatter.py @@ -41,11 +41,18 @@ async def test_stream_formatter_floating_domain_first() -> None: formatter = StreamFormatter(request_id="req-2") frames = await _collect( formatter, - _stream(("floating_domain", "notes"), ("token", "Summary")), + _stream( + ( + "floating_domain", + {"type": "node", "id": "n-1", "section": None}, + ), + ("token", "Summary"), + ), ) assert isinstance(frames[0], WsFloatingDomain) - assert frames[0].domain == "notes" + assert frames[0].domain.type == "node" + assert frames[0].domain.id == "n-1" assert isinstance(frames[1], WsStreamStart) assert isinstance(frames[2], WsStreamText) assert frames[2].chunk == "Summary" diff --git a/tests/test_schemas_v3.py b/tests/test_schemas_v3.py index 16dc611..a354ca3 100644 --- a/tests/test_schemas_v3.py +++ b/tests/test_schemas_v3.py @@ -4,6 +4,7 @@ import pytest from pydantic import ValidationError from app.schemas import ( + WsDomain, WsFrameType, WsHomeRequest, WsFloatingDomain, @@ -195,28 +196,47 @@ def test_stream_end_deserializes(): def test_floating_domain_tasks(): - frame = WsFloatingDomain(request_id="r1", domain="tasks") + frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task")) assert frame.type == WsFrameType.floating_domain - assert frame.domain == "tasks" + assert frame.domain.type == "task" -@pytest.mark.parametrize("domain", ["tasks", "timelines", "notes", "projects"]) -def test_floating_domain_valid_domains(domain: str): - frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type] - assert frame.domain == domain +def test_floating_domain_valid_domains(): + frame = WsFloatingDomain( + request_id="r1", + domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"), + ) + assert frame.domain.type == "project" + assert frame.domain.id == "213213-312321-312312-421321" + assert frame.domain.section == "task" -def test_floating_domain_invalid(): - with pytest.raises(ValidationError): - WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type] +def test_floating_domain_object_valid(): + frame = WsFloatingDomain( + request_id="r1", + domain=WsDomain(type="project", id="p1", section="task"), + ) + assert frame.domain.type == "project" def test_floating_domain_serializes(): - d = WsFloatingDomain(request_id="r1", domain="notes").model_dump() - assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"} + d = WsFloatingDomain( + request_id="r1", + domain=WsDomain(type="timeline"), + ).model_dump() + assert d == { + "type": "floating_domain", + "request_id": "r1", + "domain": {"type": "timeline", "id": None, "section": None}, + } def test_floating_domain_deserializes(): - raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"} + raw = { + "type": "floating_domain", + "request_id": "r1", + "domain": {"type": "node", "id": "n-1", "section": None}, + } frame = WsFloatingDomain.model_validate(raw) - assert frame.domain == "projects" + assert frame.domain.type == "node" + assert frame.domain.id == "n-1" diff --git a/tests/test_ws_unified.py b/tests/test_ws_unified.py index 41fd689..2af4364 100644 --- a/tests/test_ws_unified.py +++ b/tests/test_ws_unified.py @@ -50,7 +50,7 @@ async def _mock_home_stream(user_id, message, context): async def _mock_floating_stream(user_id, message, context): - yield "floating_domain", "tasks" + yield "floating_domain", {"type": "task", "id": None, "section": None} yield "token", "Here is a summary" @@ -102,7 +102,7 @@ def test_floating_request_produces_domain_frame(client): assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end) domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain) - assert domain_frame["domain"] == "tasks" + assert domain_frame["domain"]["type"] == "task" assert domain_frame["request_id"] == "p1"