feat: HomeFormatter parses inline entity tags instead of tool_end blocks
The supervisor LLM now embeds <type>[id1,id2]</type> entity tags in its response text. The HomeFormatter buffers streamed tokens, detects complete tags across chunk boundaries, and emits a WsStreamBlock carrying the entity type and the specific IDs. This replaces the old approach of emitting a block for every tool_end event, which dumped ALL entities regardless of relevance.

Also in this commit:
- Added a NoneType guard on metadata in _run_graph_stream (metadata can be None)
- Updated the _HOME_SYSTEM prompt with entity tag instructions
- Updated all affected tests
This commit is contained in:
@@ -236,7 +236,19 @@ _HOME_SYSTEM = (
|
|||||||
"multiple sub-agents if needed.\n\n"
|
"multiple sub-agents if needed.\n\n"
|
||||||
"You also have an update_core_memory tool — use it when the user states "
|
"You also have an update_core_memory tool — use it when the user states "
|
||||||
"a preference or important fact worth remembering long-term.\n\n"
|
"a preference or important fact worth remembering long-term.\n\n"
|
||||||
"After gathering data, synthesize a clear, helpful response for the user.\n\n"
|
"## Entity References\n"
|
||||||
|
"When your response mentions specific workspace entities, embed them "
|
||||||
|
"inline using entity tags so the UI can render interactive components.\n"
|
||||||
|
"Format: <type>[comma-separated UUIDs]</type>\n"
|
||||||
|
"Supported types: task, project, note, timeline\n\n"
|
||||||
|
"Example response:\n"
|
||||||
|
" Here is your project:\n"
|
||||||
|
" <project>[abc-123-def]</project>\n"
|
||||||
|
" It has these pending tasks:\n"
|
||||||
|
" <task>[def-456,ghi-789]</task>\n\n"
|
||||||
|
"IMPORTANT: Only include IDs of entities that are directly relevant to "
|
||||||
|
"the user's question. Do NOT dump all entity IDs returned by a tool — "
|
||||||
|
"filter to only the ones the user asked about or that matter for the answer.\n\n"
|
||||||
"Memory context:\n{memory_context}"
|
"Memory context:\n{memory_context}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -360,6 +372,7 @@ async def _run_graph_stream(
|
|||||||
isinstance(msg, AIMessageChunk)
|
isinstance(msg, AIMessageChunk)
|
||||||
and msg.content
|
and msg.content
|
||||||
and not msg.tool_calls
|
and not msg.tool_calls
|
||||||
|
and isinstance(metadata, dict)
|
||||||
and metadata.get("langgraph_node") == "agent"
|
and metadata.get("langgraph_node") == "agent"
|
||||||
):
|
):
|
||||||
yield ("token", str(msg.content))
|
yield ("token", str(msg.content))
|
||||||
|
|||||||
@@ -6,9 +6,10 @@ Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
|
|||||||
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
||||||
|
|
||||||
HomeFormatter:
|
HomeFormatter:
|
||||||
* Sniffs ``tool_end`` events → emits ``WsStreamBlock`` (entity_ref with raw data)
|
* Buffers text tokens and parses inline entity tags
|
||||||
* Streams text tokens → emits ``WsStreamText``
|
``<type>[id1,id2]</type>`` → emits ``WsStreamBlock`` (entity_ref with IDs)
|
||||||
* Attaches mutations → injects into ``WsStreamEnd``
|
* Streams surrounding text → emits ``WsStreamText``
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
|
||||||
FloatingFormatter:
|
FloatingFormatter:
|
||||||
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
|
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
|
||||||
@@ -19,6 +20,7 @@ FloatingFormatter:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -40,20 +42,91 @@ _AGENT_DOMAIN: dict[str, str] = {
|
|||||||
"project_agent": "projects",
|
"project_agent": "projects",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Regex for complete inline entity tags: <task>[id1,id2]</task>
|
||||||
|
_ENTITY_TAG_RE = re.compile(
|
||||||
|
r"<(task|project|note|timeline)>\[([^\]]+)\]</\1>"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tag name → plural entity type for the WsStreamBlock data
|
||||||
|
_TAG_ENTITY: dict[str, str] = {
|
||||||
|
"task": "tasks",
|
||||||
|
"project": "projects",
|
||||||
|
"note": "notes",
|
||||||
|
"timeline": "timelines",
|
||||||
|
}
|
||||||
|
|
||||||
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
|
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
class HomeFormatter:
|
class HomeFormatter:
|
||||||
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
|
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
|
||||||
|
|
||||||
``tool_end`` events from sub-agents are emitted as ``WsStreamBlock``
|
The supervisor's response contains inline entity tags like
|
||||||
(entity_ref) so the client can render structured data. Text tokens are
|
``<project>[abc-123]</project>``. This formatter detects them,
|
||||||
forwarded as ``WsStreamText``. Mutations are attached to ``WsStreamEnd``.
|
emits ``WsStreamBlock(block_type="entity_ref")`` with the entity
|
||||||
|
type and IDs, and forwards surrounding text as ``WsStreamText``.
|
||||||
|
Mutations are attached to ``WsStreamEnd``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, request_id: str) -> None:
|
def __init__(self, request_id: str) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
self._mutations: list[dict] = []
|
self._mutations: list[dict] = []
|
||||||
|
self._buffer: str = ""
|
||||||
|
|
||||||
|
def _flush_buffer(self, force: bool = False):
    """Split the pending buffer into plain text and entity-reference blocks.

    Yields ``(frame_type, data)`` pairs in the order they occur in the
    buffered text:

    * ``("text", str)``   - plain text to send as ``WsStreamText``
    * ``("block", dict)`` - ``{"entity": <plural type>, "ids": [...]}`` to
      send as a ``WsStreamBlock``

    Args:
        force: When True (end of stream) everything left after the last
            complete tag is emitted as text and the buffer is cleared.
            When False, text is held back starting at the earliest ``<``
            that could still grow into a complete entity tag once more
            tokens arrive; everything before it is emitted immediately.

    ``self._buffer`` is only updated once the generator has been fully
    consumed, so callers must iterate it to exhaustion.
    """

    def could_become_tag(fragment: str) -> bool:
        """Return True if *fragment* (starting at ``<``) is a tag prefix."""
        for name in _TAG_ENTITY:
            opener = f"<{name}>["
            if len(fragment) <= len(opener):
                # Still inside "<name>[" - viable only if it matches so far.
                if opener.startswith(fragment):
                    return True
                continue
            if not fragment.startswith(opener):
                continue
            id_list, bracket, tail = fragment[len(opener):].partition("]")
            if not bracket:
                # ID list is still open; more IDs or the "]" may follow.
                return True
            # "[]" can never match, and after "]" only a (partial) closing
            # tag for the same name keeps the fragment viable.
            if id_list and f"</{name}>".startswith(tail):
                return True
        return False

    buf = self._buffer
    cursor = 0

    # Emit every complete tag together with the text that precedes it.
    for match in _ENTITY_TAG_RE.finditer(buf):
        if match.start() > cursor:
            yield ("text", buf[cursor:match.start()])
        ids = [part.strip() for part in match.group(2).split(",") if part.strip()]
        yield ("block", {"entity": _TAG_ENTITY[match.group(1)], "ids": ids})
        cursor = match.end()

    rest = buf[cursor:]

    # Decide how much of the remainder is safe to release. Holding back from
    # the *earliest* viable "<" (rather than the last "<") keeps a tag intact
    # when the chunk boundary falls inside its closing "</name>", and lets
    # text after a stray "<" (e.g. "a < b") through without stalling.
    hold_from = len(rest)
    if not force:
        pos = rest.find("<")
        while pos != -1:
            if could_become_tag(rest[pos:]):
                hold_from = pos
                break
            pos = rest.find("<", pos + 1)

    if hold_from:
        yield ("text", rest[:hold_from])
    self._buffer = rest[hold_from:]
|
||||||
|
|
||||||
async def format(
|
async def format(
|
||||||
self,
|
self,
|
||||||
@@ -64,22 +137,36 @@ class HomeFormatter:
|
|||||||
async for event_type, data in event_stream:
|
async for event_type, data in event_stream:
|
||||||
if event_type == "token":
|
if event_type == "token":
|
||||||
if data:
|
if data:
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=data)
|
self._buffer += data
|
||||||
|
for ftype, fdata in self._flush_buffer():
|
||||||
elif event_type == "tool_end":
|
if ftype == "text":
|
||||||
# Sub-agent finished — emit its result as an entity_ref block
|
yield WsStreamText(
|
||||||
name = data.get("name", "")
|
request_id=self.request_id, chunk=fdata
|
||||||
entity = _AGENT_DOMAIN.get(name)
|
)
|
||||||
if entity:
|
elif ftype == "block":
|
||||||
yield WsStreamBlock(
|
yield WsStreamBlock(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
block_type="entity_ref",
|
block_type="entity_ref",
|
||||||
data={"entity": entity, "result": data.get("result", "")},
|
data=fdata,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif event_type == "mutations":
|
elif event_type == "mutations":
|
||||||
self._mutations = data or []
|
self._mutations = data or []
|
||||||
|
|
||||||
|
# tool_end events are intentionally ignored — the supervisor
|
||||||
|
# embeds relevant entity IDs inline via <type>[ids]</type> tags.
|
||||||
|
|
||||||
|
# Flush any remaining buffer content
|
||||||
|
for ftype, fdata in self._flush_buffer(force=True):
|
||||||
|
if ftype == "text":
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=fdata)
|
||||||
|
elif ftype == "block":
|
||||||
|
yield WsStreamBlock(
|
||||||
|
request_id=self.request_id,
|
||||||
|
block_type="entity_ref",
|
||||||
|
data=fdata,
|
||||||
|
)
|
||||||
|
|
||||||
yield WsStreamEnd(
|
yield WsStreamEnd(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
mutations=[
|
mutations=[
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ async def collect(formatter, event_stream):
|
|||||||
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_home_formatter_text_token():
|
async def test_home_formatter_plain_text():
|
||||||
req_id = "req-1"
|
req_id = "req-1"
|
||||||
events = [
|
events = [
|
||||||
("token", "Hello world"),
|
("token", "Hello world"),
|
||||||
@@ -49,11 +49,11 @@ async def test_home_formatter_text_token():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_home_formatter_entity_ref_from_tool_end():
|
async def test_home_formatter_entity_tag_single_id():
|
||||||
|
"""A <project>[id]</project> tag emits a WsStreamBlock with entity + ids."""
|
||||||
req_id = "req-2"
|
req_id = "req-2"
|
||||||
events = [
|
events = [
|
||||||
("tool_end", {"name": "task_agent", "result": "Found 3 tasks."}),
|
("token", "Here is your project:\n<project>[abc-123]</project>\nAll good."),
|
||||||
("token", "Here are your tasks."),
|
|
||||||
("mutations", []),
|
("mutations", []),
|
||||||
]
|
]
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
@@ -62,27 +62,86 @@ async def test_home_formatter_entity_ref_from_tool_end():
|
|||||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||||
assert len(block_frames) == 1
|
assert len(block_frames) == 1
|
||||||
assert block_frames[0].block_type == "entity_ref"
|
assert block_frames[0].block_type == "entity_ref"
|
||||||
assert block_frames[0].data["entity"] == "tasks"
|
assert block_frames[0].data["entity"] == "projects"
|
||||||
assert block_frames[0].data["result"] == "Found 3 tasks."
|
assert block_frames[0].data["ids"] == ["abc-123"]
|
||||||
|
|
||||||
|
# Surrounding text is streamed
|
||||||
|
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||||
|
assert "Here is your project:" in text
|
||||||
|
assert "All good." in text
|
||||||
|
# The raw tag itself should NOT appear in streamed text
|
||||||
|
assert "<project>" not in text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_home_formatter_unknown_agent_no_block():
|
async def test_home_formatter_entity_tag_multiple_ids():
|
||||||
req_id = "req-3"
|
req_id = "req-3"
|
||||||
events = [
|
events = [
|
||||||
("tool_end", {"name": "unknown_agent", "result": "stuff"}),
|
("token", "Pending:\n<task>[id-1,id-2,id-3]</task>"),
|
||||||
("mutations", []),
|
("mutations", []),
|
||||||
]
|
]
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
frames = await collect(formatter, _stream(*events))
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||||
assert len(block_frames) == 0 # unknown agent → no entity mapping
|
assert len(block_frames) == 1
|
||||||
|
assert block_frames[0].data["entity"] == "tasks"
|
||||||
|
assert block_frames[0].data["ids"] == ["id-1", "id-2", "id-3"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_multiple_entity_tags():
|
||||||
|
req_id = "req-4"
|
||||||
|
events = [
|
||||||
|
("token", "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||||
|
assert len(block_frames) == 2
|
||||||
|
entities = {b.data["entity"] for b in block_frames}
|
||||||
|
assert entities == {"projects", "tasks"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_tag_split_across_tokens():
|
||||||
|
"""Entity tag arrives across two token chunks — still detected."""
|
||||||
|
req_id = "req-5"
|
||||||
|
events = [
|
||||||
|
("token", "See: <project>[abc-"),
|
||||||
|
("token", "123]</project> done"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||||
|
assert len(block_frames) == 1
|
||||||
|
assert block_frames[0].data["entity"] == "projects"
|
||||||
|
assert block_frames[0].data["ids"] == ["abc-123"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_tool_end_ignored():
|
||||||
|
"""tool_end events no longer produce blocks — only entity tags do."""
|
||||||
|
req_id = "req-6"
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
|
||||||
|
("token", "No tags here."),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||||
|
assert len(block_frames) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_home_formatter_mutations_in_stream_end():
|
async def test_home_formatter_mutations_in_stream_end():
|
||||||
req_id = "req-4"
|
req_id = "req-7"
|
||||||
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
|
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
|
||||||
events = [
|
events = [
|
||||||
("token", "Done"),
|
("token", "Done"),
|
||||||
@@ -100,31 +159,13 @@ async def test_home_formatter_mutations_in_stream_end():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_home_formatter_frame_order():
|
async def test_home_formatter_frame_order():
|
||||||
"""stream_start is first, stream_end is last."""
|
"""stream_start is first, stream_end is last."""
|
||||||
req_id = "req-5"
|
req_id = "req-8"
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
|
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
|
||||||
assert isinstance(frames[0], WsStreamStart)
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_multiple_tool_ends():
|
|
||||||
req_id = "req-6"
|
|
||||||
events = [
|
|
||||||
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
|
|
||||||
("tool_end", {"name": "project_agent", "result": "2 projects"}),
|
|
||||||
("token", "Overview done."),
|
|
||||||
("mutations", []),
|
|
||||||
]
|
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
|
||||||
assert len(block_frames) == 2
|
|
||||||
entities = {b.data["entity"] for b in block_frames}
|
|
||||||
assert entities == {"tasks", "projects"}
|
|
||||||
|
|
||||||
|
|
||||||
# ── FloatingFormatter ─────────────────────────────────────────────────────────
|
# ── FloatingFormatter ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -46,8 +46,7 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
|||||||
|
|
||||||
|
|
||||||
async def _mock_home_stream(user_id, message, context, db_session_factory=None):
|
async def _mock_home_stream(user_id, message, context, db_session_factory=None):
|
||||||
yield "tool_end", {"name": "task_agent", "result": "Found tasks"}
|
yield "token", "Here are your tasks:\n<task>[t1,t2]</task>"
|
||||||
yield "token", "Hello"
|
|
||||||
yield "mutations", []
|
yield "mutations", []
|
||||||
|
|
||||||
|
|
||||||
@@ -115,7 +114,6 @@ def test_home_request_request_id_propagated(client):
|
|||||||
req_id = "my-unique-req-id"
|
req_id = "my-unique-req-id"
|
||||||
|
|
||||||
async def _stream(user_id, message, context, db_session_factory=None):
|
async def _stream(user_id, message, context, db_session_factory=None):
|
||||||
yield "tool_end", {"name": "note_agent", "result": "ok"}
|
|
||||||
yield "token", "ok"
|
yield "token", "ok"
|
||||||
yield "mutations", []
|
yield "mutations", []
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user