step-4: add output formatting layer (output_formatter.py)

HomeFormatter parses JSON block stream from orchestrator tokens and emits
stream_start / stream_text / stream_block / stream_end frames.
PopupFormatter emits popup_domain then plain stream_text.
All 13 unit tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-08 21:51:20 +01:00
parent 2c08275934
commit 393b3befd6
3 changed files with 440 additions and 1 deletions

View File

@@ -212,7 +212,7 @@ pytest tests/test_output_formatter.py
``` ```
**Status**: **Status**:
- [ ] Step 4 complete - [x] Step 4 complete
**Commit**: After tests pass, commit with: **Commit**: After tests pass, commit with:
``` ```

View File

@@ -0,0 +1,244 @@
"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
PopupFormatter: produces popup_domain, stream_text, stream_end
"""
from __future__ import annotations
import json
import logging
from collections.abc import AsyncGenerator
from typing import Any
from app.schemas import (
WsPopupDomain,
WsStreamBlock,
WsStreamEnd,
WsStreamStart,
WsStreamText,
)
logger = logging.getLogger(__name__)
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}
# Map agent name → popup domain
_AGENT_DOMAIN: dict[str, str] = {
"task_agent": "tasks",
"checkpoint_agent": "checkpoints",
"note_agent": "notes",
"project_agent": "projects",
}
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsPopupDomain
class HomeFormatter:
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
The LLM is expected to output a newline-delimited sequence of JSON objects,
each with a ``type`` field:
- ``text`` → yields WsStreamText immediately (word-by-word)
- ``chart`` → buffers full JSON, validates, yields WsStreamBlock
- ``entity_ref`` → resolves from tool_results, yields WsStreamBlock
- ``table`` → buffers full JSON, validates, yields WsStreamBlock
- ``timeline`` → buffers full JSON, validates, yields WsStreamBlock
Invalid or unknown blocks are logged and skipped — stream never crashes.
"""
def __init__(self, request_id: str, tool_results: list[dict]) -> None:
self.request_id = request_id
self.tool_results = tool_results
async def format(
self,
token_stream: AsyncGenerator[tuple[str, str], None],
) -> AsyncGenerator[WsFrame, None]:
yield WsStreamStart(request_id=self.request_id)
buffer = ""
async for _agent_name, token in token_stream:
if not token:
continue
buffer += token
# Flush any complete JSON objects from the buffer
async for frame in self._flush_complete_objects(buffer):
buffer = "" # reset after flush
yield frame
break # only one flush per iteration; rest accumulates
# Flush any remaining content
if buffer.strip():
async for frame in self._flush_complete_objects(buffer, final=True):
yield frame
yield WsStreamEnd(request_id=self.request_id)
async def _flush_complete_objects(
self, text: str, final: bool = False
) -> AsyncGenerator[WsFrame, None]:
"""Try to parse and yield all complete JSON objects from *text*.
Yields nothing if text is incomplete JSON (unless *final* is True,
in which case remaining text is emitted as plain stream_text).
"""
remaining = text.strip()
while remaining:
# Fast path: plain text (not JSON)
if not remaining.startswith("{"):
# Yield as plain text chunk
newline_idx = remaining.find("\n")
if newline_idx == -1:
if final:
yield WsStreamText(request_id=self.request_id, chunk=remaining)
remaining = ""
else:
return # accumulate more
else:
line = remaining[:newline_idx].strip()
remaining = remaining[newline_idx + 1:].strip()
if line:
yield WsStreamText(request_id=self.request_id, chunk=line)
continue
# Try to decode a JSON object
try:
obj, end_idx = _try_parse_json(remaining)
except ValueError:
if final:
# Emit as raw text if we can't parse
yield WsStreamText(request_id=self.request_id, chunk=remaining)
remaining = ""
return
if obj is None:
if final:
yield WsStreamText(request_id=self.request_id, chunk=remaining)
remaining = ""
return # incomplete — need more tokens
remaining = remaining[end_idx:].strip()
block_type = obj.get("type")
frame = self._dispatch_block(obj, block_type)
if frame is not None:
yield frame
def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
if block_type == "text":
content = obj.get("content", "")
if content:
return WsStreamText(request_id=self.request_id, chunk=str(content))
return None
if block_type == "chart":
chart_type = obj.get("chartType")
if chart_type not in _VALID_CHART_TYPES:
logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
return None
if not isinstance(obj.get("data"), list):
logger.warning("HomeFormatter: chart missing data array — skipping")
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="chart",
data=obj,
)
if block_type == "entity_ref":
entity = obj.get("entity")
resolved = self._resolve_entity(entity)
if resolved is None:
logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="entity_ref",
data={"entity": entity, "items": resolved},
)
if block_type == "table":
if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
logger.warning("HomeFormatter: table missing headers/rows — skipping")
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="table",
data=obj,
)
if block_type == "timeline":
if not isinstance(obj.get("checkpoints"), list):
logger.warning("HomeFormatter: timeline missing checkpoints — skipping")
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="timeline",
data=obj,
)
logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
return None
def _resolve_entity(self, entity: str | None) -> list[dict] | None:
"""Find matching items in tool_results by entity type."""
if not entity:
return None
matches = [r for r in self.tool_results if r.get("entity") == entity]
return matches if matches else None
class PopupFormatter:
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
Emits popup_domain immediately (from agent_name), then streams all tokens
as plain stream_text — no block parsing for popup context.
"""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
async def format(
self,
token_stream: AsyncGenerator[tuple[str, str], None],
) -> AsyncGenerator[WsFrame, None]:
domain_sent = False
async for agent_name, token in token_stream:
if not domain_sent:
domain = _AGENT_DOMAIN.get(agent_name, "tasks")
yield WsPopupDomain(
request_id=self.request_id,
domain=domain, # type: ignore[arg-type]
)
yield WsStreamStart(request_id=self.request_id)
domain_sent = True
if token:
yield WsStreamText(request_id=self.request_id, chunk=token)
yield WsStreamEnd(request_id=self.request_id)
# ── helpers ───────────────────────────────────────────────────────────────────
def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
"""Attempt to parse the first complete JSON object from *text*.
Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the
object is incomplete, and raises ``ValueError`` when text is not JSON.
"""
decoder = json.JSONDecoder()
try:
obj, end_idx = decoder.raw_decode(text)
if not isinstance(obj, dict):
raise ValueError("Expected JSON object")
return obj, end_idx
except json.JSONDecodeError as exc:
# Incomplete JSON — need more tokens
if "Unterminated" in str(exc) or exc.pos == len(text):
return None, 0
raise ValueError(str(exc)) from exc

View File

@@ -0,0 +1,195 @@
"""Tests for app.core.output_formatter — HomeFormatter and PopupFormatter."""
from __future__ import annotations
import pytest
from app.core.output_formatter import HomeFormatter, PopupFormatter
from app.schemas import (
WsPopupDomain,
WsStreamBlock,
WsStreamEnd,
WsStreamStart,
WsStreamText,
)
# ── helpers ───────────────────────────────────────────────────────────────────
async def _stream(*pairs: tuple[str, str]):
"""Async generator that yields (agent_name, token) pairs."""
for pair in pairs:
yield pair
async def collect(formatter, token_stream):
frames = []
async for frame in formatter.format(token_stream):
frames.append(frame)
return frames
# ── HomeFormatter ─────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_home_formatter_text_block():
req_id = "req-1"
tokens = [
("task_agent", '{"type": "text", "content": "Hello world"}'),
]
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(*tokens))
assert isinstance(frames[0], WsStreamStart)
assert frames[0].request_id == req_id
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
assert any("Hello world" in f.chunk for f in text_frames)
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_home_formatter_chart_block():
req_id = "req-2"
chart_json = (
'{"type": "chart", "chartType": "bar", '
'"title": "Tasks", "data": [{"x": 1}], '
'"config": {"x": {"label": "X", "color": "#fff"}}}'
)
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", chart_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].block_type == "chart"
assert block_frames[0].data["chartType"] == "bar"
@pytest.mark.asyncio
async def test_home_formatter_invalid_chart_skipped():
req_id = "req-3"
bad_chart = '{"type": "chart", "chartType": "unknown", "data": []}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", bad_chart)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 0 # invalid chart skipped
@pytest.mark.asyncio
async def test_home_formatter_entity_ref_resolved():
req_id = "req-4"
tool_results = [{"entity": "task", "id": "t1", "title": "My Task"}]
entity_json = '{"type": "entity_ref", "entity": "task"}'
formatter = HomeFormatter(request_id=req_id, tool_results=tool_results)
frames = await collect(formatter, _stream(("task_agent", entity_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].data["entity"] == "task"
assert block_frames[0].data["items"][0]["id"] == "t1"
@pytest.mark.asyncio
async def test_home_formatter_entity_ref_missing_skipped():
req_id = "req-5"
entity_json = '{"type": "entity_ref", "entity": "task"}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", entity_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 0 # no tool results → skipped
@pytest.mark.asyncio
async def test_home_formatter_table_block():
req_id = "req-6"
table_json = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", table_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].block_type == "table"
@pytest.mark.asyncio
async def test_home_formatter_timeline_block():
req_id = "req-7"
timeline_json = '{"type": "timeline", "checkpoints": [{"id": "c1", "title": "M1", "date": 123}]}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", timeline_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].block_type == "timeline"
@pytest.mark.asyncio
async def test_home_formatter_frame_order():
"""stream_start is first, stream_end is last."""
req_id = "req-8"
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", '{"type": "text", "content": "Hi"}')))
assert isinstance(frames[0], WsStreamStart)
assert isinstance(frames[-1], WsStreamEnd)
# ── PopupFormatter ────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_popup_formatter_domain_emitted_first():
req_id = "pop-1"
formatter = PopupFormatter(request_id=req_id)
tokens = [
("task_agent", ""), # domain signal
("task_agent", "Hello"),
("task_agent", " there"),
]
frames = await collect(formatter, _stream(*tokens))
assert isinstance(frames[0], WsPopupDomain)
assert frames[0].domain == "tasks"
assert frames[0].request_id == req_id
@pytest.mark.asyncio
async def test_popup_formatter_text_only():
req_id = "pop-2"
formatter = PopupFormatter(request_id=req_id)
tokens = [("checkpoint_agent", ""), ("checkpoint_agent", "Summary")]
frames = await collect(formatter, _stream(*tokens))
assert isinstance(frames[0], WsPopupDomain)
assert frames[0].domain == "checkpoints"
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
assert len(text_frames) == 1
assert text_frames[0].chunk == "Summary"
@pytest.mark.asyncio
async def test_popup_formatter_no_block_frames():
"""PopupFormatter must never emit WsStreamBlock."""
req_id = "pop-3"
formatter = PopupFormatter(request_id=req_id)
tokens = [
("note_agent", ""),
("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'),
]
frames = await collect(formatter, _stream(*tokens))
assert not any(isinstance(f, WsStreamBlock) for f in frames)
@pytest.mark.asyncio
async def test_popup_formatter_end_frame():
req_id = "pop-4"
formatter = PopupFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done")))
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_popup_formatter_unknown_agent_defaults_to_tasks():
req_id = "pop-5"
formatter = PopupFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi")))
assert frames[0].domain == "tasks"