feat: HomeFormatter parses inline entity tags instead of tool_end blocks
The supervisor LLM now embeds <type>[id1,id2]</type> entity tags in its response text. The HomeFormatter buffers streamed tokens, detects complete tags even when they are split across chunk boundaries, and emits a WsStreamBlock carrying the entity type and the specific IDs. This replaces the old approach of emitting a block for every tool_end event, which dumped ALL entities regardless of relevance.

Also in this commit:
- Added a NoneType guard on metadata in _run_graph_stream (metadata can be None)
- Updated the _HOME_SYSTEM prompt with entity tag instructions
- Updated all affected tests
This commit is contained in:
@@ -32,7 +32,7 @@ async def collect(formatter, event_stream):
|
||||
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_text_token():
|
||||
async def test_home_formatter_plain_text():
|
||||
req_id = "req-1"
|
||||
events = [
|
||||
("token", "Hello world"),
|
||||
@@ -49,11 +49,11 @@ async def test_home_formatter_text_token():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_entity_ref_from_tool_end():
|
||||
async def test_home_formatter_entity_tag_single_id():
|
||||
"""A <project>[id]</project> tag emits a WsStreamBlock with entity + ids."""
|
||||
req_id = "req-2"
|
||||
events = [
|
||||
("tool_end", {"name": "task_agent", "result": "Found 3 tasks."}),
|
||||
("token", "Here are your tasks."),
|
||||
("token", "Here is your project:\n<project>[abc-123]</project>\nAll good."),
|
||||
("mutations", []),
|
||||
]
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
@@ -62,27 +62,86 @@ async def test_home_formatter_entity_ref_from_tool_end():
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].block_type == "entity_ref"
|
||||
assert block_frames[0].data["entity"] == "tasks"
|
||||
assert block_frames[0].data["result"] == "Found 3 tasks."
|
||||
assert block_frames[0].data["entity"] == "projects"
|
||||
assert block_frames[0].data["ids"] == ["abc-123"]
|
||||
|
||||
# Surrounding text is streamed
|
||||
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||
assert "Here is your project:" in text
|
||||
assert "All good." in text
|
||||
# The raw tag itself should NOT appear in streamed text
|
||||
assert "<project>" not in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_unknown_agent_no_block():
|
||||
async def test_home_formatter_entity_tag_multiple_ids():
|
||||
req_id = "req-3"
|
||||
events = [
|
||||
("tool_end", {"name": "unknown_agent", "result": "stuff"}),
|
||||
("token", "Pending:\n<task>[id-1,id-2,id-3]</task>"),
|
||||
("mutations", []),
|
||||
]
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 0 # unknown agent → no entity mapping
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].data["entity"] == "tasks"
|
||||
assert block_frames[0].data["ids"] == ["id-1", "id-2", "id-3"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_multiple_entity_tags():
|
||||
req_id = "req-4"
|
||||
events = [
|
||||
("token", "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"),
|
||||
("mutations", []),
|
||||
]
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 2
|
||||
entities = {b.data["entity"] for b in block_frames}
|
||||
assert entities == {"projects", "tasks"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_tag_split_across_tokens():
|
||||
"""Entity tag arrives across two token chunks — still detected."""
|
||||
req_id = "req-5"
|
||||
events = [
|
||||
("token", "See: <project>[abc-"),
|
||||
("token", "123]</project> done"),
|
||||
("mutations", []),
|
||||
]
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].data["entity"] == "projects"
|
||||
assert block_frames[0].data["ids"] == ["abc-123"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_tool_end_ignored():
|
||||
"""tool_end events no longer produce blocks — only entity tags do."""
|
||||
req_id = "req-6"
|
||||
events = [
|
||||
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
|
||||
("token", "No tags here."),
|
||||
("mutations", []),
|
||||
]
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_mutations_in_stream_end():
|
||||
req_id = "req-4"
|
||||
req_id = "req-7"
|
||||
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
|
||||
events = [
|
||||
("token", "Done"),
|
||||
@@ -100,31 +159,13 @@ async def test_home_formatter_mutations_in_stream_end():
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_frame_order():
|
||||
"""stream_start is first, stream_end is last."""
|
||||
req_id = "req-5"
|
||||
req_id = "req-8"
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
|
||||
assert isinstance(frames[0], WsStreamStart)
|
||||
assert isinstance(frames[-1], WsStreamEnd)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_multiple_tool_ends():
|
||||
req_id = "req-6"
|
||||
events = [
|
||||
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
|
||||
("tool_end", {"name": "project_agent", "result": "2 projects"}),
|
||||
("token", "Overview done."),
|
||||
("mutations", []),
|
||||
]
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 2
|
||||
entities = {b.data["entity"] for b in block_frames}
|
||||
assert entities == {"tasks", "projects"}
|
||||
|
||||
|
||||
# ── FloatingFormatter ─────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -46,8 +46,7 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
||||
|
||||
|
||||
async def _mock_home_stream(user_id, message, context, db_session_factory=None):
|
||||
yield "tool_end", {"name": "task_agent", "result": "Found tasks"}
|
||||
yield "token", "Hello"
|
||||
yield "token", "Here are your tasks:\n<task>[t1,t2]</task>"
|
||||
yield "mutations", []
|
||||
|
||||
|
||||
@@ -115,7 +114,6 @@ def test_home_request_request_id_propagated(client):
|
||||
req_id = "my-unique-req-id"
|
||||
|
||||
async def _stream(user_id, message, context, db_session_factory=None):
|
||||
yield "tool_end", {"name": "note_agent", "result": "ok"}
|
||||
yield "token", "ok"
|
||||
yield "mutations", []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user