diff --git a/app/core/deep_agent.py b/app/core/deep_agent.py index dd07f10..2ede506 100644 --- a/app/core/deep_agent.py +++ b/app/core/deep_agent.py @@ -236,7 +236,19 @@ _HOME_SYSTEM = ( "multiple sub-agents if needed.\n\n" "You also have an update_core_memory tool — use it when the user states " "a preference or important fact worth remembering long-term.\n\n" - "After gathering data, synthesize a clear, helpful response for the user.\n\n" + "## Entity References\n" + "When your response mentions specific workspace entities, embed them " + "inline using entity tags so the UI can render interactive components.\n" + "Format: [comma-separated UUIDs]\n" + "Supported types: task, project, note, timeline\n\n" + "Example response:\n" + " Here is your project:\n" + " [abc-123-def]\n" + " It has these pending tasks:\n" + " [def-456,ghi-789]\n\n" + "IMPORTANT: Only include IDs of entities that are directly relevant to " + "the user's question. Do NOT dump all entity IDs returned by a tool — " + "filter to only the ones the user asked about or that matter for the answer.\n\n" "Memory context:\n{memory_context}" ) @@ -360,6 +372,7 @@ async def _run_graph_stream( isinstance(msg, AIMessageChunk) and msg.content and not msg.tool_calls + and isinstance(metadata, dict) and metadata.get("langgraph_node") == "agent" ): yield ("token", str(msg.content)) diff --git a/app/core/output_formatter.py b/app/core/output_formatter.py index a5106e3..d6a4833 100644 --- a/app/core/output_formatter.py +++ b/app/core/output_formatter.py @@ -6,9 +6,10 @@ Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``: * ``("mutations", list)`` — collected CRUD mutations for ``stream_end`` HomeFormatter: - * Sniffs ``tool_end`` events → emits ``WsStreamBlock`` (entity_ref with raw data) - * Streams text tokens → emits ``WsStreamText`` - * Attaches mutations → injects into ``WsStreamEnd`` + * Buffers text tokens and parses inline entity tags + ``[id1,id2]`` → emits ``WsStreamBlock`` 
# A complete inline entity tag: an opening ``<type>`` marker immediately
# followed by a bracketed, comma-separated ID list.
_ENTITY_TAG_RE = re.compile(
    r"<(task|project|note|timeline)>\[([^\]]+)\]"
)

# Tag name → plural entity type for the WsStreamBlock data.
# The keys must stay in sync with the alternation in ``_ENTITY_TAG_RE``;
# they are also used to recognise half-received tags while streaming.
_TAG_ENTITY: dict[str, str] = {
    "task": "tasks",
    "project": "projects",
    "note": "notes",
    "timeline": "timelines",
}


class HomeFormatter:
    """Split streamed supervisor text into plain text and entity-ref blocks.

    Tokens are accumulated in ``self._buffer``. ``_flush_buffer`` extracts
    every complete entity tag (as matched by ``_ENTITY_TAG_RE``) as a
    ``("block", {...})`` item and everything else as ``("text", str)``.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._mutations: list[dict] = []
        # Text received so far that has not been emitted yet.
        self._buffer: str = ""

    @staticmethod
    def _partial_tag_start(text: str) -> int:
        """Return the index of the earliest ``<`` that may still become a tag.

        A tail is considered viable when it is a prefix of some ``<type>[``
        opener, or when it starts with an opener and no closing ``]`` has
        arrived yet. Returns ``len(text)`` when nothing needs holding back.
        """
        openers = [f"<{name}>[" for name in _TAG_ENTITY]
        idx = text.find("<")
        while idx != -1:
            tail = text[idx:]
            for opener in openers:
                # Opener itself is still incomplete, e.g. "<proj".
                if opener.startswith(tail):
                    return idx
                # Opener complete, ID list still open, e.g. "<task>[abc-".
                if tail.startswith(opener) and "]" not in tail:
                    return idx
            idx = text.find("<", idx + 1)
        return len(text)

    def _flush_buffer(self, force: bool = False):
        """Extract complete entity tags and emit-ready text from the buffer.

        Yields ``(frame_type, data)`` pairs:
            ``("text", str)``   — plain text to send as WsStreamText
            ``("block", dict)`` — ``{"entity": <plural type>, "ids": [...]}``
                                  to send as WsStreamBlock

        Args:
            force: When True (end of stream) the whole remainder is emitted
                as text and the buffer is cleared. When False, only a
                trailing fragment that can still grow into an entity tag is
                kept in ``self._buffer``; all other text is emitted now.

        Holding back only viable tag prefixes (rather than everything after
        the last ``<``) keeps a stray ``<`` in ordinary prose from stalling
        the stream, and keeps a ``<`` inside a half-received ID list from
        splitting the tag.
        """
        buf = self._buffer
        pos = 0

        for m in _ENTITY_TAG_RE.finditer(buf):
            # Text between the previous tag (or buffer start) and this tag.
            if m.start() > pos:
                yield ("text", buf[pos : m.start()])

            ids = [part.strip() for part in m.group(2).split(",") if part.strip()]
            yield ("block", {"entity": _TAG_ENTITY[m.group(1)], "ids": ids})
            pos = m.end()

        rest = buf[pos:]
        held = ""
        if not force:
            cut = self._partial_tag_start(rest)
            rest, held = rest[:cut], rest[cut:]

        if rest:
            yield ("text", rest)
        self._buffer = held

    # NOTE(review): ``format`` is not reproduced here — its signature and the
    # tail of its body lie outside the visible diff hunks. Keep the existing
    # implementation; it only relies on ``_flush_buffer`` yielding
    # ("text", str) / ("block", dict) pairs, which is unchanged.
+ + # Flush any remaining buffer content + for ftype, fdata in self._flush_buffer(force=True): + if ftype == "text": + yield WsStreamText(request_id=self.request_id, chunk=fdata) + elif ftype == "block": + yield WsStreamBlock( + request_id=self.request_id, + block_type="entity_ref", + data=fdata, + ) + yield WsStreamEnd( request_id=self.request_id, mutations=[ diff --git a/tests/test_output_formatter.py b/tests/test_output_formatter.py index 817f887..cb2d68a 100644 --- a/tests/test_output_formatter.py +++ b/tests/test_output_formatter.py @@ -32,7 +32,7 @@ async def collect(formatter, event_stream): # ── HomeFormatter ───────────────────────────────────────────────────────────── @pytest.mark.asyncio -async def test_home_formatter_text_token(): +async def test_home_formatter_plain_text(): req_id = "req-1" events = [ ("token", "Hello world"), @@ -49,11 +49,11 @@ async def test_home_formatter_text_token(): @pytest.mark.asyncio -async def test_home_formatter_entity_ref_from_tool_end(): +async def test_home_formatter_entity_tag_single_id(): + """A [id] tag emits a WsStreamBlock with entity + ids.""" req_id = "req-2" events = [ - ("tool_end", {"name": "task_agent", "result": "Found 3 tasks."}), - ("token", "Here are your tasks."), + ("token", "Here is your project:\n[abc-123]\nAll good."), ("mutations", []), ] formatter = HomeFormatter(request_id=req_id) @@ -62,27 +62,86 @@ async def test_home_formatter_entity_ref_from_tool_end(): block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] assert len(block_frames) == 1 assert block_frames[0].block_type == "entity_ref" - assert block_frames[0].data["entity"] == "tasks" - assert block_frames[0].data["result"] == "Found 3 tasks." + assert block_frames[0].data["entity"] == "projects" + assert block_frames[0].data["ids"] == ["abc-123"] + + # Surrounding text is streamed + text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText)) + assert "Here is your project:" in text + assert "All good." 
in text + # The raw tag itself should NOT appear in streamed text + assert "" not in text @pytest.mark.asyncio -async def test_home_formatter_unknown_agent_no_block(): +async def test_home_formatter_entity_tag_multiple_ids(): req_id = "req-3" events = [ - ("tool_end", {"name": "unknown_agent", "result": "stuff"}), + ("token", "Pending:\n[id-1,id-2,id-3]"), ("mutations", []), ] formatter = HomeFormatter(request_id=req_id) frames = await collect(formatter, _stream(*events)) block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] - assert len(block_frames) == 0 # unknown agent → no entity mapping + assert len(block_frames) == 1 + assert block_frames[0].data["entity"] == "tasks" + assert block_frames[0].data["ids"] == ["id-1", "id-2", "id-3"] + + +@pytest.mark.asyncio +async def test_home_formatter_multiple_entity_tags(): + req_id = "req-4" + events = [ + ("token", "[p1]\nText\n[t1,t2]"), + ("mutations", []), + ] + formatter = HomeFormatter(request_id=req_id) + frames = await collect(formatter, _stream(*events)) + + block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] + assert len(block_frames) == 2 + entities = {b.data["entity"] for b in block_frames} + assert entities == {"projects", "tasks"} + + +@pytest.mark.asyncio +async def test_home_formatter_tag_split_across_tokens(): + """Entity tag arrives across two token chunks — still detected.""" + req_id = "req-5" + events = [ + ("token", "See: [abc-"), + ("token", "123] done"), + ("mutations", []), + ] + formatter = HomeFormatter(request_id=req_id) + frames = await collect(formatter, _stream(*events)) + + block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] + assert len(block_frames) == 1 + assert block_frames[0].data["entity"] == "projects" + assert block_frames[0].data["ids"] == ["abc-123"] + + +@pytest.mark.asyncio +async def test_home_formatter_tool_end_ignored(): + """tool_end events no longer produce blocks — only entity tags do.""" + req_id = "req-6" + events = [ + 
("tool_end", {"name": "task_agent", "result": "3 tasks"}), + ("token", "No tags here."), + ("mutations", []), + ] + formatter = HomeFormatter(request_id=req_id) + frames = await collect(formatter, _stream(*events)) + + block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] + assert len(block_frames) == 0 @pytest.mark.asyncio async def test_home_formatter_mutations_in_stream_end(): - req_id = "req-4" + req_id = "req-7" muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}] events = [ ("token", "Done"), @@ -100,31 +159,13 @@ async def test_home_formatter_mutations_in_stream_end(): @pytest.mark.asyncio async def test_home_formatter_frame_order(): """stream_start is first, stream_end is last.""" - req_id = "req-5" + req_id = "req-8" formatter = HomeFormatter(request_id=req_id) frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", []))) assert isinstance(frames[0], WsStreamStart) assert isinstance(frames[-1], WsStreamEnd) -@pytest.mark.asyncio -async def test_home_formatter_multiple_tool_ends(): - req_id = "req-6" - events = [ - ("tool_end", {"name": "task_agent", "result": "3 tasks"}), - ("tool_end", {"name": "project_agent", "result": "2 projects"}), - ("token", "Overview done."), - ("mutations", []), - ] - formatter = HomeFormatter(request_id=req_id) - frames = await collect(formatter, _stream(*events)) - - block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] - assert len(block_frames) == 2 - entities = {b.data["entity"] for b in block_frames} - assert entities == {"tasks", "projects"} - - # ── FloatingFormatter ───────────────────────────────────────────────────────── @pytest.mark.asyncio diff --git a/tests/test_ws_unified.py b/tests/test_ws_unified.py index c770448..58408c6 100644 --- a/tests/test_ws_unified.py +++ b/tests/test_ws_unified.py @@ -46,8 +46,7 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]: async def _mock_home_stream(user_id, message, context, db_session_factory=None): - yield 
"tool_end", {"name": "task_agent", "result": "Found tasks"} - yield "token", "Hello" + yield "token", "Here are your tasks:\n[t1,t2]" yield "mutations", [] @@ -115,7 +114,6 @@ def test_home_request_request_id_propagated(client): req_id = "my-unique-req-id" async def _stream(user_id, message, context, db_session_factory=None): - yield "tool_end", {"name": "note_agent", "result": "ok"} yield "token", "ok" yield "mutations", []