diff --git a/app/core/output_formatter.py b/app/core/output_formatter.py index d6a4833..a8e43f2 100644 --- a/app/core/output_formatter.py +++ b/app/core/output_formatter.py @@ -6,9 +6,9 @@ Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``: * ``("mutations", list)`` — collected CRUD mutations for ``stream_end`` HomeFormatter: - * Buffers text tokens and parses inline entity tags - ``[id1,id2]`` → emits ``WsStreamBlock`` (entity_ref with IDs) - * Streams surrounding text → emits ``WsStreamText`` + * Streams text tokens as-is → emits ``WsStreamText`` + (text may contain inline ``[id,...]`` entity tags + for the frontend to parse and render as interactive components) * Attaches mutations → injects into ``WsStreamEnd`` FloatingFormatter: @@ -20,13 +20,11 @@ FloatingFormatter: from __future__ import annotations import logging -import re from collections.abc import AsyncGenerator from typing import Any from app.schemas import ( WsFloatingDomain, - WsStreamBlock, WsStreamEnd, WsStreamStart, WsStreamText, @@ -42,91 +40,21 @@ _AGENT_DOMAIN: dict[str, str] = { "project_agent": "projects", } -# Regex for complete inline entity tags: [id1,id2] -_ENTITY_TAG_RE = re.compile( - r"<(task|project|note|timeline)>\[([^\]]+)\]" -) - -# Tag name → plural entity type for the WsStreamBlock data -_TAG_ENTITY: dict[str, str] = { - "task": "tasks", - "project": "projects", - "note": "notes", - "timeline": "timelines", -} - -WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain +WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain class HomeFormatter: """Consumes a deep-agent event stream and yields WS frames for the Home view. - The supervisor's response contains inline entity tags like - ``[abc-123]``. This formatter detects them, - emits ``WsStreamBlock(block_type="entity_ref")`` with the entity - type and IDs, and forwards surrounding text as ``WsStreamText``. + Text tokens are forwarded as-is via ``WsStreamText``. The supervisor + embeds ``[id1,id2]`` entity tags inline — the frontend + is responsible for parsing those and rendering interactive components. Mutations are attached to ``WsStreamEnd``. """ def __init__(self, request_id: str) -> None: self.request_id = request_id self._mutations: list[dict] = [] - self._buffer: str = "" - - def _flush_buffer(self, force: bool = False): - """Extract complete entity tags and text from the buffer. - - Yields (frame_type, data) pairs: - ("text", str) — plain text to send as WsStreamText - ("block", dict) — entity_ref block to send as WsStreamBlock - - When *force* is True (end of stream), the entire buffer is flushed. - Otherwise, text after the last unmatched ``<`` is held back in case - it is the start of an entity tag arriving across token boundaries. - """ - buf = self._buffer - - while True: - m = _ENTITY_TAG_RE.search(buf) - if not m: - break - - # Text before the tag - before = buf[: m.start()] - if before: - yield ("text", before) - - # The entity tag itself → a block - tag_type = m.group(1) - raw_ids = m.group(2) - ids = [i.strip() for i in raw_ids.split(",") if i.strip()] - yield ( - "block", - { - "entity": _TAG_ENTITY[tag_type], - "ids": ids, - }, - ) - - buf = buf[m.end() :] - - if force: - # End of stream — flush everything that remains - if buf: - yield ("text", buf) - self._buffer = "" - else: - # Keep a potential partial tag (text after last '<') in the buffer - last_lt = buf.rfind("<") - if last_lt != -1: - safe = buf[:last_lt] - if safe: - yield ("text", safe) - self._buffer = buf[last_lt:] - else: - if buf: - yield ("text", buf) - self._buffer = "" async def format( self, @@ -137,36 +65,11 @@ class HomeFormatter: async for event_type, data in event_stream: if event_type == "token": if data: - self._buffer += data - for ftype, fdata in self._flush_buffer(): - if ftype == "text": - yield WsStreamText( - request_id=self.request_id, chunk=fdata - ) - elif ftype == "block": - yield WsStreamBlock( - request_id=self.request_id, - block_type="entity_ref", - data=fdata, - ) + yield WsStreamText(request_id=self.request_id, chunk=data) elif event_type == "mutations": self._mutations = data or [] - # tool_end events are intentionally ignored — the supervisor - # embeds relevant entity IDs inline via [ids] tags. - - # Flush any remaining buffer content - for ftype, fdata in self._flush_buffer(force=True): - if ftype == "text": - yield WsStreamText(request_id=self.request_id, chunk=fdata) - elif ftype == "block": - yield WsStreamBlock( - request_id=self.request_id, - block_type="entity_ref", - data=fdata, - ) - yield WsStreamEnd( request_id=self.request_id, mutations=[ diff --git a/tests/test_output_formatter.py b/tests/test_output_formatter.py index cb2d68a..087d2b9 100644 --- a/tests/test_output_formatter.py +++ b/tests/test_output_formatter.py @@ -7,7 +7,6 @@ import pytest from app.core.output_formatter import HomeFormatter, FloatingFormatter from app.schemas import ( WsFloatingDomain, - WsStreamBlock, WsStreamEnd, WsStreamStart, WsStreamText, @@ -49,8 +48,8 @@ async def test_home_formatter_plain_text(): @pytest.mark.asyncio -async def test_home_formatter_entity_tag_single_id(): - """A [id] tag emits a WsStreamBlock with entity + ids.""" +async def test_home_formatter_entity_tags_passed_through(): + """Entity tags are streamed as-is — the frontend parses them.""" req_id = "req-2" events = [ ("token", "Here is your project:\n[abc-123]\nAll good."), @@ -59,39 +58,15 @@ async def test_home_formatter_entity_tag_single_id(): formatter = HomeFormatter(request_id=req_id) frames = await collect(formatter, _stream(*events)) - block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] - assert len(block_frames) == 1 - assert block_frames[0].block_type == "entity_ref" - assert block_frames[0].data["entity"] == "projects" - assert block_frames[0].data["ids"] == ["abc-123"] - - # Surrounding text is streamed text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText)) + assert "[abc-123]" in text assert "Here is your project:" in text assert "All good." in text - # The raw tag itself should NOT appear in streamed text - assert "" not in text @pytest.mark.asyncio -async def test_home_formatter_entity_tag_multiple_ids(): +async def test_home_formatter_multiple_tags_passed_through(): req_id = "req-3" - events = [ - ("token", "Pending:\n[id-1,id-2,id-3]"), - ("mutations", []), - ] - formatter = HomeFormatter(request_id=req_id) - frames = await collect(formatter, _stream(*events)) - - block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] - assert len(block_frames) == 1 - assert block_frames[0].data["entity"] == "tasks" - assert block_frames[0].data["ids"] == ["id-1", "id-2", "id-3"] - - -@pytest.mark.asyncio -async def test_home_formatter_multiple_entity_tags(): - req_id = "req-4" events = [ ("token", "[p1]\nText\n[t1,t2]"), ("mutations", []), @@ -99,34 +74,15 @@ async def test_home_formatter_multiple_entity_tags(): formatter = HomeFormatter(request_id=req_id) frames = await collect(formatter, _stream(*events)) - block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] - assert len(block_frames) == 2 - entities = {b.data["entity"] for b in block_frames} - assert entities == {"projects", "tasks"} - - -@pytest.mark.asyncio -async def test_home_formatter_tag_split_across_tokens(): - """Entity tag arrives across two token chunks — still detected.""" - req_id = "req-5" - events = [ - ("token", "See: [abc-"), - ("token", "123] done"), - ("mutations", []), - ] - formatter = HomeFormatter(request_id=req_id) - frames = await collect(formatter, _stream(*events)) - - block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] - assert len(block_frames) == 1 - assert block_frames[0].data["entity"] == "projects" - assert block_frames[0].data["ids"] == ["abc-123"] + text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText)) + assert "[p1]" in text + assert "[t1,t2]" in text @pytest.mark.asyncio async def test_home_formatter_tool_end_ignored(): - """tool_end events no longer produce blocks — only entity tags do.""" - req_id = "req-6" + """tool_end events are silently ignored by HomeFormatter.""" + req_id = "req-4" events = [ ("tool_end", {"name": "task_agent", "result": "3 tasks"}), ("token", "No tags here."), @@ -135,13 +91,13 @@ async def test_home_formatter_tool_end_ignored(): formatter = HomeFormatter(request_id=req_id) frames = await collect(formatter, _stream(*events)) - block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] - assert len(block_frames) == 0 + text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText)) + assert text == "No tags here." @pytest.mark.asyncio async def test_home_formatter_mutations_in_stream_end(): - req_id = "req-7" + req_id = "req-5" muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}] events = [ ("token", "Done"), @@ -159,7 +115,7 @@ async def test_home_formatter_mutations_in_stream_end(): @pytest.mark.asyncio async def test_home_formatter_frame_order(): """stream_start is first, stream_end is last.""" - req_id = "req-8" + req_id = "req-6" formatter = HomeFormatter(request_id=req_id) frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", []))) assert isinstance(frames[0], WsStreamStart) @@ -203,8 +159,8 @@ async def test_floating_formatter_text_only(): @pytest.mark.asyncio -async def test_floating_formatter_no_block_frames(): - """FloatingFormatter must never emit WsStreamBlock.""" +async def test_floating_formatter_no_entity_tags(): + """FloatingFormatter never emits entity tag blocks.""" req_id = "pop-3" formatter = FloatingFormatter(request_id=req_id) events = [ @@ -213,7 +169,9 @@ async def test_floating_formatter_no_block_frames(): ("mutations", []), ] frames = await collect(formatter, _stream(*events)) - assert not any(isinstance(f, WsStreamBlock) for f in frames) + # Only expected frame types + for f in frames: + assert isinstance(f, (WsFloatingDomain, WsStreamStart, WsStreamText, WsStreamEnd)) @pytest.mark.asyncio