From 393b3befd6efcc224f59bdb6962058b96ffb1df1 Mon Sep 17 00:00:00 2001 From: roberto Date: Sun, 8 Mar 2026 21:51:20 +0100 Subject: [PATCH] step-4: add output formatting layer (output_formatter.py) HomeFormatter parses JSON block stream from orchestrator tokens and emits stream_start / stream_text / stream_block / stream_end frames. PopupFormatter emits popup_domain then plain stream_text. All 13 unit tests pass. Co-Authored-By: Claude Sonnet 4.6 --- V3_MIGRATION_PLAN.md | 2 +- app/core/output_formatter.py | 244 +++++++++++++++++++++++++++++++++ tests/test_output_formatter.py | 195 ++++++++++++++++++++++++++ 3 files changed, 440 insertions(+), 1 deletion(-) create mode 100644 app/core/output_formatter.py create mode 100644 tests/test_output_formatter.py diff --git a/V3_MIGRATION_PLAN.md b/V3_MIGRATION_PLAN.md index 090923f..30eca16 100644 --- a/V3_MIGRATION_PLAN.md +++ b/V3_MIGRATION_PLAN.md @@ -212,7 +212,7 @@ pytest tests/test_output_formatter.py ``` **Status**: -- [ ] Step 4 complete +- [x] Step 4 complete **Commit**: After tests pass, commit with: ``` diff --git a/app/core/output_formatter.py b/app/core/output_formatter.py new file mode 100644 index 0000000..c5880f4 --- /dev/null +++ b/app/core/output_formatter.py @@ -0,0 +1,244 @@ +"""Output Formatter — transforms orchestrator token streams into WS frame sequences. + +HomeFormatter: produces stream_start, stream_text / stream_block, stream_end +PopupFormatter: produces popup_domain, stream_text, stream_end +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import AsyncGenerator +from typing import Any + +from app.schemas import ( + WsPopupDomain, + WsStreamBlock, + WsStreamEnd, + WsStreamStart, + WsStreamText, +) + +logger = logging.getLogger(__name__) + +# Valid chart types (matching shadcn/ui Recharts wrappers in Electron) +_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"} + +# Map agent name → popup domain +_AGENT_DOMAIN: dict[str, str] = { + "task_agent": "tasks", + "checkpoint_agent": "checkpoints", + "note_agent": "notes", + "project_agent": "projects", +} + +WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsPopupDomain + + +class HomeFormatter: + """Parses a token stream from orchestrate_v3_stream and yields WS frames. + + The LLM is expected to output a newline-delimited sequence of JSON objects, + each with a ``type`` field: + - ``text`` → yields WsStreamText immediately (word-by-word) + - ``chart`` → buffers full JSON, validates, yields WsStreamBlock + - ``entity_ref`` → resolves from tool_results, yields WsStreamBlock + - ``table`` → buffers full JSON, validates, yields WsStreamBlock + - ``timeline`` → buffers full JSON, validates, yields WsStreamBlock + + Invalid or unknown blocks are logged and skipped — stream never crashes. + """ + + def __init__(self, request_id: str, tool_results: list[dict]) -> None: + self.request_id = request_id + self.tool_results = tool_results + + async def format( + self, + token_stream: AsyncGenerator[tuple[str, str], None], + ) -> AsyncGenerator[WsFrame, None]: + yield WsStreamStart(request_id=self.request_id) + + buffer = "" + async for _agent_name, token in token_stream: + if not token: + continue + buffer += token + # Flush any complete JSON objects from the buffer + async for frame in self._flush_complete_objects(buffer): + buffer = "" # reset after flush + yield frame + break # only one flush per iteration; rest accumulates + + # Flush any remaining content + if buffer.strip(): + async for frame in self._flush_complete_objects(buffer, final=True): + yield frame + + yield WsStreamEnd(request_id=self.request_id) + + async def _flush_complete_objects( + self, text: str, final: bool = False + ) -> AsyncGenerator[WsFrame, None]: + """Try to parse and yield all complete JSON objects from *text*. + + Yields nothing if text is incomplete JSON (unless *final* is True, + in which case remaining text is emitted as plain stream_text). + """ + remaining = text.strip() + while remaining: + # Fast path: plain text (not JSON) + if not remaining.startswith("{"): + # Yield as plain text chunk + newline_idx = remaining.find("\n") + if newline_idx == -1: + if final: + yield WsStreamText(request_id=self.request_id, chunk=remaining) + remaining = "" + else: + return # accumulate more + else: + line = remaining[:newline_idx].strip() + remaining = remaining[newline_idx + 1:].strip() + if line: + yield WsStreamText(request_id=self.request_id, chunk=line) + continue + + # Try to decode a JSON object + try: + obj, end_idx = _try_parse_json(remaining) + except ValueError: + if final: + # Emit as raw text if we can't parse + yield WsStreamText(request_id=self.request_id, chunk=remaining) + remaining = "" + return + + if obj is None: + if final: + yield WsStreamText(request_id=self.request_id, chunk=remaining) + remaining = "" + return # incomplete — need more tokens + + remaining = remaining[end_idx:].strip() + block_type = obj.get("type") + + frame = self._dispatch_block(obj, block_type) + if frame is not None: + yield frame + + def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None: + if block_type == "text": + content = obj.get("content", "") + if content: + return WsStreamText(request_id=self.request_id, chunk=str(content)) + return None + + if block_type == "chart": + chart_type = obj.get("chartType") + if chart_type not in _VALID_CHART_TYPES: + logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type) + return None + if not isinstance(obj.get("data"), list): + logger.warning("HomeFormatter: chart missing data array — skipping") + return None + return WsStreamBlock( + request_id=self.request_id, + block_type="chart", + data=obj, + ) + + if block_type == "entity_ref": + entity = obj.get("entity") + resolved = self._resolve_entity(entity) + if resolved is None: + logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity) + return None + return WsStreamBlock( + request_id=self.request_id, + block_type="entity_ref", + data={"entity": entity, "items": resolved}, + ) + + if block_type == "table": + if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list): + logger.warning("HomeFormatter: table missing headers/rows — skipping") + return None + return WsStreamBlock( + request_id=self.request_id, + block_type="table", + data=obj, + ) + + if block_type == "timeline": + if not isinstance(obj.get("checkpoints"), list): + logger.warning("HomeFormatter: timeline missing checkpoints — skipping") + return None + return WsStreamBlock( + request_id=self.request_id, + block_type="timeline", + data=obj, + ) + + logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type) + return None + + def _resolve_entity(self, entity: str | None) -> list[dict] | None: + """Find matching items in tool_results by entity type.""" + if not entity: + return None + matches = [r for r in self.tool_results if r.get("entity") == entity] + return matches if matches else None + + +class PopupFormatter: + """Parses a token stream from orchestrate_v3_stream and yields WS frames. + + Emits popup_domain immediately (from agent_name), then streams all tokens + as plain stream_text — no block parsing for popup context. + """ + + def __init__(self, request_id: str) -> None: + self.request_id = request_id + + async def format( + self, + token_stream: AsyncGenerator[tuple[str, str], None], + ) -> AsyncGenerator[WsFrame, None]: + domain_sent = False + + async for agent_name, token in token_stream: + if not domain_sent: + domain = _AGENT_DOMAIN.get(agent_name, "tasks") + yield WsPopupDomain( + request_id=self.request_id, + domain=domain, # type: ignore[arg-type] + ) + yield WsStreamStart(request_id=self.request_id) + domain_sent = True + + if token: + yield WsStreamText(request_id=self.request_id, chunk=token) + + yield WsStreamEnd(request_id=self.request_id) + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]: + """Attempt to parse the first complete JSON object from *text*. + + Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the + object is incomplete, and raises ``ValueError`` when text is not JSON. + """ + decoder = json.JSONDecoder() + try: + obj, end_idx = decoder.raw_decode(text) + if not isinstance(obj, dict): + raise ValueError("Expected JSON object") + return obj, end_idx + except json.JSONDecodeError as exc: + # Incomplete JSON — need more tokens + if "Unterminated" in str(exc) or exc.pos == len(text): + return None, 0 + raise ValueError(str(exc)) from exc diff --git a/tests/test_output_formatter.py b/tests/test_output_formatter.py new file mode 100644 index 0000000..f59b7f9 --- /dev/null +++ b/tests/test_output_formatter.py @@ -0,0 +1,195 @@ +"""Tests for app.core.output_formatter — HomeFormatter and PopupFormatter.""" + +from __future__ import annotations + +import pytest + +from app.core.output_formatter import HomeFormatter, PopupFormatter +from app.schemas import ( + WsPopupDomain, + WsStreamBlock, + WsStreamEnd, + WsStreamStart, + WsStreamText, +) + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +async def _stream(*pairs: tuple[str, str]): + """Async generator that yields (agent_name, token) pairs.""" + for pair in pairs: + yield pair + + +async def collect(formatter, token_stream): + frames = [] + async for frame in formatter.format(token_stream): + frames.append(frame) + return frames + + +# ── HomeFormatter ───────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_home_formatter_text_block(): + req_id = "req-1" + tokens = [ + ("task_agent", '{"type": "text", "content": "Hello world"}'), + ] + formatter = HomeFormatter(request_id=req_id, tool_results=[]) + frames = await collect(formatter, _stream(*tokens)) + + assert isinstance(frames[0], WsStreamStart) + assert frames[0].request_id == req_id + text_frames = [f for f in frames if isinstance(f, WsStreamText)] + assert any("Hello world" in f.chunk for f in text_frames) + assert isinstance(frames[-1], WsStreamEnd) + + +@pytest.mark.asyncio +async def test_home_formatter_chart_block(): + req_id = "req-2" + chart_json = ( + '{"type": "chart", "chartType": "bar", ' + '"title": "Tasks", "data": [{"x": 1}], ' + '"config": {"x": {"label": "X", "color": "#fff"}}}' + ) + formatter = HomeFormatter(request_id=req_id, tool_results=[]) + frames = await collect(formatter, _stream(("task_agent", chart_json))) + + block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] + assert len(block_frames) == 1 + assert block_frames[0].block_type == "chart" + assert block_frames[0].data["chartType"] == "bar" + + +@pytest.mark.asyncio +async def test_home_formatter_invalid_chart_skipped(): + req_id = "req-3" + bad_chart = '{"type": "chart", "chartType": "unknown", "data": []}' + formatter = HomeFormatter(request_id=req_id, tool_results=[]) + frames = await collect(formatter, _stream(("task_agent", bad_chart))) + + block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] + assert len(block_frames) == 0 # invalid chart skipped + + +@pytest.mark.asyncio +async def test_home_formatter_entity_ref_resolved(): + req_id = "req-4" + tool_results = [{"entity": "task", "id": "t1", "title": "My Task"}] + entity_json = '{"type": "entity_ref", "entity": "task"}' + formatter = HomeFormatter(request_id=req_id, tool_results=tool_results) + frames = await collect(formatter, _stream(("task_agent", entity_json))) + + block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] + assert len(block_frames) == 1 + assert block_frames[0].data["entity"] == "task" + assert block_frames[0].data["items"][0]["id"] == "t1" + + +@pytest.mark.asyncio +async def test_home_formatter_entity_ref_missing_skipped(): + req_id = "req-5" + entity_json = '{"type": "entity_ref", "entity": "task"}' + formatter = HomeFormatter(request_id=req_id, tool_results=[]) + frames = await collect(formatter, _stream(("task_agent", entity_json))) + + block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] + assert len(block_frames) == 0 # no tool results → skipped + + +@pytest.mark.asyncio +async def test_home_formatter_table_block(): + req_id = "req-6" + table_json = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}' + formatter = HomeFormatter(request_id=req_id, tool_results=[]) + frames = await collect(formatter, _stream(("task_agent", table_json))) + + block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] + assert len(block_frames) == 1 + assert block_frames[0].block_type == "table" + + +@pytest.mark.asyncio +async def test_home_formatter_timeline_block(): + req_id = "req-7" + timeline_json = '{"type": "timeline", "checkpoints": [{"id": "c1", "title": "M1", "date": 123}]}' + formatter = HomeFormatter(request_id=req_id, tool_results=[]) + frames = await collect(formatter, _stream(("task_agent", timeline_json))) + + block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] + assert len(block_frames) == 1 + assert block_frames[0].block_type == "timeline" + + +@pytest.mark.asyncio +async def test_home_formatter_frame_order(): + """stream_start is first, stream_end is last.""" + req_id = "req-8" + formatter = HomeFormatter(request_id=req_id, tool_results=[]) + frames = await collect(formatter, _stream(("task_agent", '{"type": "text", "content": "Hi"}'))) + assert isinstance(frames[0], WsStreamStart) + assert isinstance(frames[-1], WsStreamEnd) + + +# ── PopupFormatter ──────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_popup_formatter_domain_emitted_first(): + req_id = "pop-1" + formatter = PopupFormatter(request_id=req_id) + tokens = [ + ("task_agent", ""), # domain signal + ("task_agent", "Hello"), + ("task_agent", " there"), + ] + frames = await collect(formatter, _stream(*tokens)) + + assert isinstance(frames[0], WsPopupDomain) + assert frames[0].domain == "tasks" + assert frames[0].request_id == req_id + + +@pytest.mark.asyncio +async def test_popup_formatter_text_only(): + req_id = "pop-2" + formatter = PopupFormatter(request_id=req_id) + tokens = [("checkpoint_agent", ""), ("checkpoint_agent", "Summary")] + frames = await collect(formatter, _stream(*tokens)) + + assert isinstance(frames[0], WsPopupDomain) + assert frames[0].domain == "checkpoints" + text_frames = [f for f in frames if isinstance(f, WsStreamText)] + assert len(text_frames) == 1 + assert text_frames[0].chunk == "Summary" + + +@pytest.mark.asyncio +async def test_popup_formatter_no_block_frames(): + """PopupFormatter must never emit WsStreamBlock.""" + req_id = "pop-3" + formatter = PopupFormatter(request_id=req_id) + tokens = [ + ("note_agent", ""), + ("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'), + ] + frames = await collect(formatter, _stream(*tokens)) + assert not any(isinstance(f, WsStreamBlock) for f in frames) + + +@pytest.mark.asyncio +async def test_popup_formatter_end_frame(): + req_id = "pop-4" + formatter = PopupFormatter(request_id=req_id) + frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done"))) + assert isinstance(frames[-1], WsStreamEnd) + + +@pytest.mark.asyncio +async def test_popup_formatter_unknown_agent_defaults_to_tasks(): + req_id = "pop-5" + formatter = PopupFormatter(request_id=req_id) + frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi"))) + assert frames[0].domain == "tasks"