step-4: add output formatting layer (output_formatter.py)
HomeFormatter parses JSON block stream from orchestrator tokens and emits stream_start / stream_text / stream_block / stream_end frames. PopupFormatter emits popup_domain then plain stream_text. All 13 unit tests pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -212,7 +212,7 @@ pytest tests/test_output_formatter.py
|
||||
```
|
||||
|
||||
**Status**:
|
||||
- [ ] Step 4 complete
|
||||
- [x] Step 4 complete
|
||||
|
||||
**Commit**: After tests pass, commit with:
|
||||
```
|
||||
|
||||
244
app/core/output_formatter.py
Normal file
244
app/core/output_formatter.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
|
||||
|
||||
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
|
||||
PopupFormatter: produces popup_domain, stream_text, stream_end
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from app.schemas import (
|
||||
WsPopupDomain,
|
||||
WsStreamBlock,
|
||||
WsStreamEnd,
|
||||
WsStreamStart,
|
||||
WsStreamText,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
|
||||
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}
|
||||
|
||||
# Map agent name → popup domain
|
||||
_AGENT_DOMAIN: dict[str, str] = {
|
||||
"task_agent": "tasks",
|
||||
"checkpoint_agent": "checkpoints",
|
||||
"note_agent": "notes",
|
||||
"project_agent": "projects",
|
||||
}
|
||||
|
||||
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsPopupDomain
|
||||
|
||||
|
||||
class HomeFormatter:
|
||||
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
||||
|
||||
The LLM is expected to output a newline-delimited sequence of JSON objects,
|
||||
each with a ``type`` field:
|
||||
- ``text`` → yields WsStreamText immediately (word-by-word)
|
||||
- ``chart`` → buffers full JSON, validates, yields WsStreamBlock
|
||||
- ``entity_ref`` → resolves from tool_results, yields WsStreamBlock
|
||||
- ``table`` → buffers full JSON, validates, yields WsStreamBlock
|
||||
- ``timeline`` → buffers full JSON, validates, yields WsStreamBlock
|
||||
|
||||
Invalid or unknown blocks are logged and skipped — stream never crashes.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: str, tool_results: list[dict]) -> None:
|
||||
self.request_id = request_id
|
||||
self.tool_results = tool_results
|
||||
|
||||
async def format(
|
||||
self,
|
||||
token_stream: AsyncGenerator[tuple[str, str], None],
|
||||
) -> AsyncGenerator[WsFrame, None]:
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
|
||||
buffer = ""
|
||||
async for _agent_name, token in token_stream:
|
||||
if not token:
|
||||
continue
|
||||
buffer += token
|
||||
# Flush any complete JSON objects from the buffer
|
||||
async for frame in self._flush_complete_objects(buffer):
|
||||
buffer = "" # reset after flush
|
||||
yield frame
|
||||
break # only one flush per iteration; rest accumulates
|
||||
|
||||
# Flush any remaining content
|
||||
if buffer.strip():
|
||||
async for frame in self._flush_complete_objects(buffer, final=True):
|
||||
yield frame
|
||||
|
||||
yield WsStreamEnd(request_id=self.request_id)
|
||||
|
||||
async def _flush_complete_objects(
|
||||
self, text: str, final: bool = False
|
||||
) -> AsyncGenerator[WsFrame, None]:
|
||||
"""Try to parse and yield all complete JSON objects from *text*.
|
||||
|
||||
Yields nothing if text is incomplete JSON (unless *final* is True,
|
||||
in which case remaining text is emitted as plain stream_text).
|
||||
"""
|
||||
remaining = text.strip()
|
||||
while remaining:
|
||||
# Fast path: plain text (not JSON)
|
||||
if not remaining.startswith("{"):
|
||||
# Yield as plain text chunk
|
||||
newline_idx = remaining.find("\n")
|
||||
if newline_idx == -1:
|
||||
if final:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
||||
remaining = ""
|
||||
else:
|
||||
return # accumulate more
|
||||
else:
|
||||
line = remaining[:newline_idx].strip()
|
||||
remaining = remaining[newline_idx + 1:].strip()
|
||||
if line:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=line)
|
||||
continue
|
||||
|
||||
# Try to decode a JSON object
|
||||
try:
|
||||
obj, end_idx = _try_parse_json(remaining)
|
||||
except ValueError:
|
||||
if final:
|
||||
# Emit as raw text if we can't parse
|
||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
||||
remaining = ""
|
||||
return
|
||||
|
||||
if obj is None:
|
||||
if final:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
||||
remaining = ""
|
||||
return # incomplete — need more tokens
|
||||
|
||||
remaining = remaining[end_idx:].strip()
|
||||
block_type = obj.get("type")
|
||||
|
||||
frame = self._dispatch_block(obj, block_type)
|
||||
if frame is not None:
|
||||
yield frame
|
||||
|
||||
def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
|
||||
if block_type == "text":
|
||||
content = obj.get("content", "")
|
||||
if content:
|
||||
return WsStreamText(request_id=self.request_id, chunk=str(content))
|
||||
return None
|
||||
|
||||
if block_type == "chart":
|
||||
chart_type = obj.get("chartType")
|
||||
if chart_type not in _VALID_CHART_TYPES:
|
||||
logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
|
||||
return None
|
||||
if not isinstance(obj.get("data"), list):
|
||||
logger.warning("HomeFormatter: chart missing data array — skipping")
|
||||
return None
|
||||
return WsStreamBlock(
|
||||
request_id=self.request_id,
|
||||
block_type="chart",
|
||||
data=obj,
|
||||
)
|
||||
|
||||
if block_type == "entity_ref":
|
||||
entity = obj.get("entity")
|
||||
resolved = self._resolve_entity(entity)
|
||||
if resolved is None:
|
||||
logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
|
||||
return None
|
||||
return WsStreamBlock(
|
||||
request_id=self.request_id,
|
||||
block_type="entity_ref",
|
||||
data={"entity": entity, "items": resolved},
|
||||
)
|
||||
|
||||
if block_type == "table":
|
||||
if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
|
||||
logger.warning("HomeFormatter: table missing headers/rows — skipping")
|
||||
return None
|
||||
return WsStreamBlock(
|
||||
request_id=self.request_id,
|
||||
block_type="table",
|
||||
data=obj,
|
||||
)
|
||||
|
||||
if block_type == "timeline":
|
||||
if not isinstance(obj.get("checkpoints"), list):
|
||||
logger.warning("HomeFormatter: timeline missing checkpoints — skipping")
|
||||
return None
|
||||
return WsStreamBlock(
|
||||
request_id=self.request_id,
|
||||
block_type="timeline",
|
||||
data=obj,
|
||||
)
|
||||
|
||||
logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
|
||||
return None
|
||||
|
||||
def _resolve_entity(self, entity: str | None) -> list[dict] | None:
|
||||
"""Find matching items in tool_results by entity type."""
|
||||
if not entity:
|
||||
return None
|
||||
matches = [r for r in self.tool_results if r.get("entity") == entity]
|
||||
return matches if matches else None
|
||||
|
||||
|
||||
class PopupFormatter:
|
||||
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
||||
|
||||
Emits popup_domain immediately (from agent_name), then streams all tokens
|
||||
as plain stream_text — no block parsing for popup context.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: str) -> None:
|
||||
self.request_id = request_id
|
||||
|
||||
async def format(
|
||||
self,
|
||||
token_stream: AsyncGenerator[tuple[str, str], None],
|
||||
) -> AsyncGenerator[WsFrame, None]:
|
||||
domain_sent = False
|
||||
|
||||
async for agent_name, token in token_stream:
|
||||
if not domain_sent:
|
||||
domain = _AGENT_DOMAIN.get(agent_name, "tasks")
|
||||
yield WsPopupDomain(
|
||||
request_id=self.request_id,
|
||||
domain=domain, # type: ignore[arg-type]
|
||||
)
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
domain_sent = True
|
||||
|
||||
if token:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=token)
|
||||
|
||||
yield WsStreamEnd(request_id=self.request_id)
|
||||
|
||||
|
||||
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
|
||||
"""Attempt to parse the first complete JSON object from *text*.
|
||||
|
||||
Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the
|
||||
object is incomplete, and raises ``ValueError`` when text is not JSON.
|
||||
"""
|
||||
decoder = json.JSONDecoder()
|
||||
try:
|
||||
obj, end_idx = decoder.raw_decode(text)
|
||||
if not isinstance(obj, dict):
|
||||
raise ValueError("Expected JSON object")
|
||||
return obj, end_idx
|
||||
except json.JSONDecodeError as exc:
|
||||
# Incomplete JSON — need more tokens
|
||||
if "Unterminated" in str(exc) or exc.pos == len(text):
|
||||
return None, 0
|
||||
raise ValueError(str(exc)) from exc
|
||||
195
tests/test_output_formatter.py
Normal file
195
tests/test_output_formatter.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Tests for app.core.output_formatter — HomeFormatter and PopupFormatter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.output_formatter import HomeFormatter, PopupFormatter
|
||||
from app.schemas import (
|
||||
WsPopupDomain,
|
||||
WsStreamBlock,
|
||||
WsStreamEnd,
|
||||
WsStreamStart,
|
||||
WsStreamText,
|
||||
)
|
||||
|
||||
|
||||
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
async def _stream(*pairs: tuple[str, str]):
|
||||
"""Async generator that yields (agent_name, token) pairs."""
|
||||
for pair in pairs:
|
||||
yield pair
|
||||
|
||||
|
||||
async def collect(formatter, token_stream):
|
||||
frames = []
|
||||
async for frame in formatter.format(token_stream):
|
||||
frames.append(frame)
|
||||
return frames
|
||||
|
||||
|
||||
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_text_block():
|
||||
req_id = "req-1"
|
||||
tokens = [
|
||||
("task_agent", '{"type": "text", "content": "Hello world"}'),
|
||||
]
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(*tokens))
|
||||
|
||||
assert isinstance(frames[0], WsStreamStart)
|
||||
assert frames[0].request_id == req_id
|
||||
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
||||
assert any("Hello world" in f.chunk for f in text_frames)
|
||||
assert isinstance(frames[-1], WsStreamEnd)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_chart_block():
|
||||
req_id = "req-2"
|
||||
chart_json = (
|
||||
'{"type": "chart", "chartType": "bar", '
|
||||
'"title": "Tasks", "data": [{"x": 1}], '
|
||||
'"config": {"x": {"label": "X", "color": "#fff"}}}'
|
||||
)
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(("task_agent", chart_json)))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].block_type == "chart"
|
||||
assert block_frames[0].data["chartType"] == "bar"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_invalid_chart_skipped():
|
||||
req_id = "req-3"
|
||||
bad_chart = '{"type": "chart", "chartType": "unknown", "data": []}'
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(("task_agent", bad_chart)))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 0 # invalid chart skipped
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_entity_ref_resolved():
|
||||
req_id = "req-4"
|
||||
tool_results = [{"entity": "task", "id": "t1", "title": "My Task"}]
|
||||
entity_json = '{"type": "entity_ref", "entity": "task"}'
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=tool_results)
|
||||
frames = await collect(formatter, _stream(("task_agent", entity_json)))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].data["entity"] == "task"
|
||||
assert block_frames[0].data["items"][0]["id"] == "t1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_entity_ref_missing_skipped():
|
||||
req_id = "req-5"
|
||||
entity_json = '{"type": "entity_ref", "entity": "task"}'
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(("task_agent", entity_json)))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 0 # no tool results → skipped
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_table_block():
|
||||
req_id = "req-6"
|
||||
table_json = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}'
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(("task_agent", table_json)))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].block_type == "table"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_timeline_block():
|
||||
req_id = "req-7"
|
||||
timeline_json = '{"type": "timeline", "checkpoints": [{"id": "c1", "title": "M1", "date": 123}]}'
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(("task_agent", timeline_json)))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].block_type == "timeline"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_frame_order():
|
||||
"""stream_start is first, stream_end is last."""
|
||||
req_id = "req-8"
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(("task_agent", '{"type": "text", "content": "Hi"}')))
|
||||
assert isinstance(frames[0], WsStreamStart)
|
||||
assert isinstance(frames[-1], WsStreamEnd)
|
||||
|
||||
|
||||
# ── PopupFormatter ────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_popup_formatter_domain_emitted_first():
|
||||
req_id = "pop-1"
|
||||
formatter = PopupFormatter(request_id=req_id)
|
||||
tokens = [
|
||||
("task_agent", ""), # domain signal
|
||||
("task_agent", "Hello"),
|
||||
("task_agent", " there"),
|
||||
]
|
||||
frames = await collect(formatter, _stream(*tokens))
|
||||
|
||||
assert isinstance(frames[0], WsPopupDomain)
|
||||
assert frames[0].domain == "tasks"
|
||||
assert frames[0].request_id == req_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_popup_formatter_text_only():
|
||||
req_id = "pop-2"
|
||||
formatter = PopupFormatter(request_id=req_id)
|
||||
tokens = [("checkpoint_agent", ""), ("checkpoint_agent", "Summary")]
|
||||
frames = await collect(formatter, _stream(*tokens))
|
||||
|
||||
assert isinstance(frames[0], WsPopupDomain)
|
||||
assert frames[0].domain == "checkpoints"
|
||||
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
||||
assert len(text_frames) == 1
|
||||
assert text_frames[0].chunk == "Summary"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_popup_formatter_no_block_frames():
|
||||
"""PopupFormatter must never emit WsStreamBlock."""
|
||||
req_id = "pop-3"
|
||||
formatter = PopupFormatter(request_id=req_id)
|
||||
tokens = [
|
||||
("note_agent", ""),
|
||||
("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'),
|
||||
]
|
||||
frames = await collect(formatter, _stream(*tokens))
|
||||
assert not any(isinstance(f, WsStreamBlock) for f in frames)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_popup_formatter_end_frame():
|
||||
req_id = "pop-4"
|
||||
formatter = PopupFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done")))
|
||||
assert isinstance(frames[-1], WsStreamEnd)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_popup_formatter_unknown_agent_defaults_to_tasks():
|
||||
req_id = "pop-5"
|
||||
formatter = PopupFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi")))
|
||||
assert frames[0].domain == "tasks"
|
||||
Reference in New Issue
Block a user