simplify HomeFormatter to pass-through — frontend handles entity tag parsing
This commit is contained in:
@@ -6,9 +6,9 @@ Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
|
||||
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
||||
|
||||
HomeFormatter:
|
||||
* Buffers text tokens and parses inline entity tags
|
||||
``<type>[id1,id2]</type>`` → emits ``WsStreamBlock`` (entity_ref with IDs)
|
||||
* Streams surrounding text → emits ``WsStreamText``
|
||||
* Streams text tokens as-is → emits ``WsStreamText``
|
||||
(text may contain inline ``<type>[id,...]</type>`` entity tags
|
||||
for the frontend to parse and render as interactive components)
|
||||
* Attaches mutations → injects into ``WsStreamEnd``
|
||||
|
||||
FloatingFormatter:
|
||||
@@ -20,13 +20,11 @@ FloatingFormatter:
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from app.schemas import (
|
||||
WsFloatingDomain,
|
||||
WsStreamBlock,
|
||||
WsStreamEnd,
|
||||
WsStreamStart,
|
||||
WsStreamText,
|
||||
@@ -42,91 +40,21 @@ _AGENT_DOMAIN: dict[str, str] = {
|
||||
"project_agent": "projects",
|
||||
}
|
||||
|
||||
# Regex for complete inline entity tags: <task>[id1,id2]</task>
|
||||
_ENTITY_TAG_RE = re.compile(
|
||||
r"<(task|project|note|timeline)>\[([^\]]+)\]</\1>"
|
||||
)
|
||||
|
||||
# Tag name → plural entity type for the WsStreamBlock data
|
||||
_TAG_ENTITY: dict[str, str] = {
|
||||
"task": "tasks",
|
||||
"project": "projects",
|
||||
"note": "notes",
|
||||
"timeline": "timelines",
|
||||
}
|
||||
|
||||
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
|
||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||
|
||||
|
||||
class HomeFormatter:
|
||||
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
|
||||
|
||||
The supervisor's response contains inline entity tags like
|
||||
``<project>[abc-123]</project>``. This formatter detects them,
|
||||
emits ``WsStreamBlock(block_type="entity_ref")`` with the entity
|
||||
type and IDs, and forwards surrounding text as ``WsStreamText``.
|
||||
Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
|
||||
embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
|
||||
is responsible for parsing those and rendering interactive components.
|
||||
Mutations are attached to ``WsStreamEnd``.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: str) -> None:
|
||||
self.request_id = request_id
|
||||
self._mutations: list[dict] = []
|
||||
self._buffer: str = ""
|
||||
|
||||
def _flush_buffer(self, force: bool = False):
    """Split the pending buffer into plain text and entity-reference blocks.

    Every complete inline tag ``<type>[id1,id2]</type>`` found in
    ``self._buffer`` becomes a block; the text around the tags is passed
    through unchanged and in its original order.

    Args:
        force: When True (end of stream) nothing is held back: whatever
            remains after the last complete tag is emitted as text and
            the buffer is emptied.  When False, a trailing fragment that
            can still grow into a complete tag is kept in
            ``self._buffer`` for the next call.

    Yields:
        ``("text", str)`` for plain text to send as ``WsStreamText`` and
        ``("block", dict)`` with keys ``"entity"`` (plural entity type)
        and ``"ids"`` (list of stripped, non-empty IDs) to send as
        ``WsStreamBlock``.
    """

    def may_complete(tail: str) -> bool:
        """Return True if *tail* (starting at ``<``) is a viable tag prefix."""
        for name in _TAG_ENTITY:
            opener = f"<{name}>["
            if opener.startswith(tail):
                # Still inside the opening ``<name>[`` part.
                return True
            if not tail.startswith(opener):
                continue
            id_text, bracket, after = tail[len(opener):].partition("]")
            if not bracket:
                # ID list not closed yet.
                return True
            # The ID list must be non-empty and whatever follows ``]``
            # must be the beginning of the matching closing tag.
            if id_text and f"</{name}>".startswith(after):
                return True
        return False

    pending = self._buffer
    cursor = 0

    for match in _ENTITY_TAG_RE.finditer(pending):
        if match.start() > cursor:
            yield ("text", pending[cursor:match.start()])

        tag_type, raw_ids = match.groups()
        yield (
            "block",
            {
                "entity": _TAG_ENTITY[tag_type],
                "ids": [
                    part.strip() for part in raw_ids.split(",") if part.strip()
                ],
            },
        )
        cursor = match.end()

    remainder = pending[cursor:]
    held = ""

    if not force:
        # Hold back from the FIRST ``<`` that can still become a tag.
        # Cutting at the last ``<`` would split a tag whose closing
        # ``</type>`` is only partially received and leak its opening
        # half as text; a ``<`` that cannot start a tag is ordinary text
        # and must not delay the stream.
        lt_pos = remainder.find("<")
        while lt_pos != -1:
            if may_complete(remainder[lt_pos:]):
                remainder, held = remainder[:lt_pos], remainder[lt_pos:]
                break
            lt_pos = remainder.find("<", lt_pos + 1)

    if remainder:
        yield ("text", remainder)
    self._buffer = held
|
||||
|
||||
async def format(
|
||||
self,
|
||||
@@ -137,36 +65,11 @@ class HomeFormatter:
|
||||
async for event_type, data in event_stream:
|
||||
if event_type == "token":
|
||||
if data:
|
||||
self._buffer += data
|
||||
for ftype, fdata in self._flush_buffer():
|
||||
if ftype == "text":
|
||||
yield WsStreamText(
|
||||
request_id=self.request_id, chunk=fdata
|
||||
)
|
||||
elif ftype == "block":
|
||||
yield WsStreamBlock(
|
||||
request_id=self.request_id,
|
||||
block_type="entity_ref",
|
||||
data=fdata,
|
||||
)
|
||||
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||
|
||||
elif event_type == "mutations":
|
||||
self._mutations = data or []
|
||||
|
||||
# tool_end events are intentionally ignored — the supervisor
|
||||
# embeds relevant entity IDs inline via <type>[ids]</type> tags.
|
||||
|
||||
# Flush any remaining buffer content
|
||||
for ftype, fdata in self._flush_buffer(force=True):
|
||||
if ftype == "text":
|
||||
yield WsStreamText(request_id=self.request_id, chunk=fdata)
|
||||
elif ftype == "block":
|
||||
yield WsStreamBlock(
|
||||
request_id=self.request_id,
|
||||
block_type="entity_ref",
|
||||
data=fdata,
|
||||
)
|
||||
|
||||
yield WsStreamEnd(
|
||||
request_id=self.request_id,
|
||||
mutations=[
|
||||
|
||||
@@ -7,7 +7,6 @@ import pytest
|
||||
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||
from app.schemas import (
|
||||
WsFloatingDomain,
|
||||
WsStreamBlock,
|
||||
WsStreamEnd,
|
||||
WsStreamStart,
|
||||
WsStreamText,
|
||||
@@ -49,8 +48,8 @@ async def test_home_formatter_plain_text():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_entity_tag_single_id():
|
||||
"""A <project>[id]</project> tag emits a WsStreamBlock with entity + ids."""
|
||||
async def test_home_formatter_entity_tags_passed_through():
|
||||
"""Entity tags are streamed as-is — the frontend parses them."""
|
||||
req_id = "req-2"
|
||||
events = [
|
||||
("token", "Here is your project:\n<project>[abc-123]</project>\nAll good."),
|
||||
@@ -59,39 +58,15 @@ async def test_home_formatter_entity_tag_single_id():
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].block_type == "entity_ref"
|
||||
assert block_frames[0].data["entity"] == "projects"
|
||||
assert block_frames[0].data["ids"] == ["abc-123"]
|
||||
|
||||
# Surrounding text is streamed
|
||||
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||
assert "<project>[abc-123]</project>" in text
|
||||
assert "Here is your project:" in text
|
||||
assert "All good." in text
|
||||
# The raw tag itself should NOT appear in streamed text
|
||||
assert "<project>" not in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_entity_tag_multiple_ids():
|
||||
async def test_home_formatter_multiple_tags_passed_through():
|
||||
req_id = "req-3"
|
||||
events = [
|
||||
("token", "Pending:\n<task>[id-1,id-2,id-3]</task>"),
|
||||
("mutations", []),
|
||||
]
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].data["entity"] == "tasks"
|
||||
assert block_frames[0].data["ids"] == ["id-1", "id-2", "id-3"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_multiple_entity_tags():
|
||||
req_id = "req-4"
|
||||
events = [
|
||||
("token", "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"),
|
||||
("mutations", []),
|
||||
@@ -99,34 +74,15 @@ async def test_home_formatter_multiple_entity_tags():
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 2
|
||||
entities = {b.data["entity"] for b in block_frames}
|
||||
assert entities == {"projects", "tasks"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_home_formatter_tag_split_across_tokens():
    """A tag delivered in two consecutive token chunks yields one block."""
    formatter = HomeFormatter(request_id="req-5")
    token_events = (
        ("token", "See: <project>[abc-"),
        ("token", "123]</project> done"),
        ("mutations", []),
    )

    frames = await collect(formatter, _stream(*token_events))

    # Exactly one entity reference, carrying the re-joined ID.
    blocks = [frame for frame in frames if isinstance(frame, WsStreamBlock)]
    assert len(blocks) == 1
    only_block = blocks[0]
    assert only_block.data["entity"] == "projects"
    assert only_block.data["ids"] == ["abc-123"]
|
||||
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||
assert "<project>[p1]</project>" in text
|
||||
assert "<task>[t1,t2]</task>" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_tool_end_ignored():
|
||||
"""tool_end events no longer produce blocks — only entity tags do."""
|
||||
req_id = "req-6"
|
||||
"""tool_end events are silently ignored by HomeFormatter."""
|
||||
req_id = "req-4"
|
||||
events = [
|
||||
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
|
||||
("token", "No tags here."),
|
||||
@@ -135,13 +91,13 @@ async def test_home_formatter_tool_end_ignored():
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 0
|
||||
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||
assert text == "No tags here."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_mutations_in_stream_end():
|
||||
req_id = "req-7"
|
||||
req_id = "req-5"
|
||||
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
|
||||
events = [
|
||||
("token", "Done"),
|
||||
@@ -159,7 +115,7 @@ async def test_home_formatter_mutations_in_stream_end():
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_frame_order():
|
||||
"""stream_start is first, stream_end is last."""
|
||||
req_id = "req-8"
|
||||
req_id = "req-6"
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
|
||||
assert isinstance(frames[0], WsStreamStart)
|
||||
@@ -203,8 +159,8 @@ async def test_floating_formatter_text_only():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_floating_formatter_no_block_frames():
|
||||
"""FloatingFormatter must never emit WsStreamBlock."""
|
||||
async def test_floating_formatter_no_entity_tags():
|
||||
"""FloatingFormatter never emits entity tag blocks."""
|
||||
req_id = "pop-3"
|
||||
formatter = FloatingFormatter(request_id=req_id)
|
||||
events = [
|
||||
@@ -213,7 +169,9 @@ async def test_floating_formatter_no_block_frames():
|
||||
("mutations", []),
|
||||
]
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
assert not any(isinstance(f, WsStreamBlock) for f in frames)
|
||||
# Only expected frame types
|
||||
for f in frames:
|
||||
assert isinstance(f, (WsFloatingDomain, WsStreamStart, WsStreamText, WsStreamEnd))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user