step-4: add output formatting layer (output_formatter.py)
HomeFormatter parses JSON block stream from orchestrator tokens and emits stream_start / stream_text / stream_block / stream_end frames. PopupFormatter emits popup_domain then plain stream_text. All 13 unit tests pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -212,7 +212,7 @@ pytest tests/test_output_formatter.py
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Status**:
|
**Status**:
|
||||||
- [ ] Step 4 complete
|
- [x] Step 4 complete
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
**Commit**: After tests pass, commit with:
|
||||||
```
|
```
|
||||||
|
|||||||
244
app/core/output_formatter.py
Normal file
244
app/core/output_formatter.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
|
||||||
|
|
||||||
|
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
|
||||||
|
PopupFormatter: produces popup_domain, stream_text, stream_end
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.schemas import (
|
||||||
|
WsPopupDomain,
|
||||||
|
WsStreamBlock,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron).
# A "chart" block whose chartType is not in this set is rejected.
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}

# Map agent name → popup domain.
# Agent names missing from this table fall back to "tasks" at lookup time.
_AGENT_DOMAIN: dict[str, str] = {
    "task_agent": "tasks",
    "checkpoint_agent": "checkpoints",
    "note_agent": "notes",
    "project_agent": "projects",
}

# Union of every frame type the formatters in this module yield.
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsPopupDomain
|
||||||
|
|
||||||
|
|
||||||
|
class HomeFormatter:
    """Parses a token stream from orchestrate_v3_stream and yields WS frames.

    Tokens are appended to a buffer that is expected to contain a
    whitespace/newline-delimited sequence of JSON objects, each with a
    ``type`` field:

    - ``text``       → WsStreamText carrying the object's ``content``
    - ``chart``      → WsStreamBlock; needs a known ``chartType`` and a ``data`` list
    - ``entity_ref`` → WsStreamBlock with the matching ``tool_results`` entries
    - ``table``      → WsStreamBlock; needs ``headers`` and ``rows`` lists
    - ``timeline``   → WsStreamBlock; needs a ``checkpoints`` list

    Lines that do not start with ``{`` are forwarded as WsStreamText.
    Invalid or unknown blocks are logged and skipped — stream never crashes.
    """

    def __init__(self, request_id: str, tool_results: list[dict]) -> None:
        self.request_id = request_id
        self.tool_results = tool_results

    async def format(
        self,
        token_stream: AsyncGenerator[tuple[str, str], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Yield stream_start, then one frame per parsed unit, then stream_end.

        Args:
            token_stream: async stream of ``(agent_name, token)`` pairs; the
                agent name is ignored and empty tokens are skipped.

        Yields:
            WsStreamStart first, WsStreamText / WsStreamBlock frames as soon
            as a complete line or JSON object is available, and WsStreamEnd
            last.
        """
        yield WsStreamStart(request_id=self.request_id)

        buffer = ""
        async for _agent_name, token in token_stream:
            if not token:
                continue
            buffer += token
            # Emit everything that is already complete and keep ONLY the
            # unconsumed tail, so neither additional complete objects nor a
            # trailing partial object/line are lost between tokens.
            frames, buffer = self._drain(buffer, final=False)
            for frame in frames:
                yield frame

        # End of stream: whatever is left can no longer be completed.
        if buffer.strip():
            frames, _ = self._drain(buffer, final=True)
            for frame in frames:
                yield frame

        yield WsStreamEnd(request_id=self.request_id)

    async def _flush_complete_objects(
        self, text: str, final: bool = False
    ) -> AsyncGenerator[WsFrame, None]:
        """Yield a frame for every complete line / JSON object in *text*.

        Yields nothing for incomplete trailing content unless *final* is True,
        in which case the remaining text is emitted as plain stream_text.
        Kept as an async-generator wrapper around ``_drain``.
        """
        frames, _unconsumed = self._drain(text, final=final)
        for frame in frames:
            yield frame

    def _drain(self, text: str, *, final: bool) -> tuple[list[WsFrame], str]:
        """Split *text* into frames plus the tail that is still incomplete.

        Returns:
            ``(frames, unconsumed)``. ``unconsumed`` is the part of *text*
            that has to wait for more tokens; it is always ``""`` when
            *final* is True. Trailing whitespace of the tail is preserved so
            that a later token can continue the same plain-text line.
        """
        frames: list[WsFrame] = []
        remaining = text.lstrip()

        while remaining:
            if not remaining.startswith("{"):
                # Plain text: emit one line at a time.
                line, newline, rest = remaining.partition("\n")
                if not newline:
                    if not final:
                        break  # line not terminated yet — accumulate more
                    frames.append(
                        WsStreamText(request_id=self.request_id, chunk=remaining.strip())
                    )
                    remaining = ""
                    break
                remaining = rest.lstrip()
                line = line.strip()
                if line:
                    frames.append(WsStreamText(request_id=self.request_id, chunk=line))
                continue

            try:
                obj, end_idx = _try_parse_json(remaining)
            except ValueError:
                # Not parseable (yet); handled like an incomplete object so a
                # half-received literal does not abort the stream.
                obj = None

            if obj is None:
                if final:
                    # Nothing more will arrive — surface the raw text.
                    frames.append(
                        WsStreamText(request_id=self.request_id, chunk=remaining.strip())
                    )
                    remaining = ""
                break

            # The object is consumed even when it is skipped as invalid, so
            # it is not parsed (and logged) again on the next token.
            remaining = remaining[end_idx:].lstrip()
            frame = self._dispatch_block(obj, obj.get("type"))
            if frame is not None:
                frames.append(frame)

        return frames, remaining

    def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
        """Convert one parsed block into a frame, or None if it is skipped.

        Blocks that fail validation or have an unknown ``type`` are logged
        with a warning and produce no frame. A ``text`` block with empty
        ``content`` is dropped silently.
        """
        if block_type == "text":
            content = obj.get("content", "")
            if content:
                return WsStreamText(request_id=self.request_id, chunk=str(content))
            return None

        if block_type == "chart":
            chart_type = obj.get("chartType")
            if chart_type not in _VALID_CHART_TYPES:
                logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
                return None
            if not isinstance(obj.get("data"), list):
                logger.warning("HomeFormatter: chart missing data array — skipping")
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="chart",
                data=obj,
            )

        if block_type == "entity_ref":
            entity = obj.get("entity")
            resolved = self._resolve_entity(entity)
            if resolved is None:
                logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="entity_ref",
                data={"entity": entity, "items": resolved},
            )

        if block_type == "table":
            if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
                logger.warning("HomeFormatter: table missing headers/rows — skipping")
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="table",
                data=obj,
            )

        if block_type == "timeline":
            if not isinstance(obj.get("checkpoints"), list):
                logger.warning("HomeFormatter: timeline missing checkpoints — skipping")
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="timeline",
                data=obj,
            )

        logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
        return None

    def _resolve_entity(self, entity: str | None) -> list[dict] | None:
        """Return the tool_results entries whose ``entity`` equals *entity*.

        Returns None when *entity* is empty or nothing matches.
        """
        if not entity:
            return None
        matches = [r for r in self.tool_results if r.get("entity") == entity]
        return matches or None
|
||||||
|
|
||||||
|
|
||||||
|
class PopupFormatter:
    """Turn an orchestrate_v3_stream token stream into popup WS frames.

    The first ``(agent_name, token)`` pair decides the popup domain: a
    popup_domain frame and a stream_start frame are emitted before any text.
    Every non-empty token is then forwarded verbatim as stream_text — popup
    context never parses blocks. stream_end is always emitted last; when the
    token stream is empty it is the only frame produced.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id

    async def format(
        self,
        token_stream: AsyncGenerator[tuple[str, str], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Yield popup_domain + stream_start once, text per token, then stream_end."""
        header_pending = True

        async for agent_name, token in token_stream:
            if header_pending:
                # Unknown agents are routed to the "tasks" domain.
                yield WsPopupDomain(
                    request_id=self.request_id,
                    domain=_AGENT_DOMAIN.get(agent_name, "tasks"),  # type: ignore[arg-type]
                )
                yield WsStreamStart(request_id=self.request_id)
                header_pending = False

            if not token:
                continue
            yield WsStreamText(request_id=self.request_id, chunk=token)

        yield WsStreamEnd(request_id=self.request_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
    """Decode the JSON object that starts at the beginning of *text*.

    Returns:
        ``(obj, end)`` where ``end`` is the index just past the object when a
        complete object was decoded, or ``(None, 0)`` when the decoder ran
        out of input (error at end of text, or an unterminated string).

    Raises:
        ValueError: the text is malformed JSON, or the decoded value is not
            a JSON object.
    """
    try:
        parsed, stop = json.JSONDecoder().raw_decode(text)
    except json.JSONDecodeError as err:
        # An error at the very end of the input, or an unterminated string,
        # means more tokens may still complete the object.
        ran_out_of_input = "Unterminated" in str(err) or err.pos == len(text)
        if not ran_out_of_input:
            raise ValueError(str(err)) from err
        return None, 0

    if isinstance(parsed, dict):
        return parsed, stop
    raise ValueError("Expected JSON object")
|
||||||
195
tests/test_output_formatter.py
Normal file
195
tests/test_output_formatter.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""Tests for app.core.output_formatter — HomeFormatter and PopupFormatter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.output_formatter import HomeFormatter, PopupFormatter
|
||||||
|
from app.schemas import (
|
||||||
|
WsPopupDomain,
|
||||||
|
WsStreamBlock,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _stream(*pairs: tuple[str, str]):
    """Replay *pairs* in order as an async stream of (agent_name, token) items."""
    pending = iter(pairs)
    for item in pending:
        yield item
|
||||||
|
|
||||||
|
|
||||||
|
async def collect(formatter, token_stream):
    """Run ``formatter.format(token_stream)`` to exhaustion and return its frames as a list."""
    return [frame async for frame in formatter.format(token_stream)]
|
||||||
|
|
||||||
|
|
||||||
|
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_home_formatter_text_block():
    """A ``text`` block is emitted as stream_text between stream_start and stream_end."""
    payload = '{"type": "text", "content": "Hello world"}'
    home = HomeFormatter(request_id="req-1", tool_results=[])

    emitted = await collect(home, _stream(("task_agent", payload)))

    opening = emitted[0]
    assert isinstance(opening, WsStreamStart)
    assert opening.request_id == "req-1"
    chunks = [frame.chunk for frame in emitted if isinstance(frame, WsStreamText)]
    assert any("Hello world" in chunk for chunk in chunks)
    assert isinstance(emitted[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_home_formatter_chart_block():
    """A valid ``chart`` block produces exactly one chart stream_block."""
    payload = (
        '{"type": "chart", "chartType": "bar", '
        '"title": "Tasks", "data": [{"x": 1}], '
        '"config": {"x": {"label": "X", "color": "#fff"}}}'
    )
    home = HomeFormatter(request_id="req-2", tool_results=[])

    emitted = await collect(home, _stream(("task_agent", payload)))

    blocks = [frame for frame in emitted if isinstance(frame, WsStreamBlock)]
    assert len(blocks) == 1
    chart = blocks[0]
    assert chart.block_type == "chart"
    assert chart.data["chartType"] == "bar"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_home_formatter_invalid_chart_skipped():
    """A chart with an unsupported chartType yields no stream_block."""
    payload = '{"type": "chart", "chartType": "unknown", "data": []}'
    home = HomeFormatter(request_id="req-3", tool_results=[])

    emitted = await collect(home, _stream(("task_agent", payload)))

    # invalid chart skipped
    assert sum(isinstance(frame, WsStreamBlock) for frame in emitted) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_home_formatter_entity_ref_resolved():
    """An ``entity_ref`` block is resolved against matching tool_results."""
    known_results = [{"entity": "task", "id": "t1", "title": "My Task"}]
    payload = '{"type": "entity_ref", "entity": "task"}'
    home = HomeFormatter(request_id="req-4", tool_results=known_results)

    emitted = await collect(home, _stream(("task_agent", payload)))

    blocks = [frame for frame in emitted if isinstance(frame, WsStreamBlock)]
    assert len(blocks) == 1
    resolved = blocks[0].data
    assert resolved["entity"] == "task"
    assert resolved["items"][0]["id"] == "t1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_home_formatter_entity_ref_missing_skipped():
    """An ``entity_ref`` with no matching tool result yields no stream_block."""
    payload = '{"type": "entity_ref", "entity": "task"}'
    home = HomeFormatter(request_id="req-5", tool_results=[])

    emitted = await collect(home, _stream(("task_agent", payload)))

    # no tool results → skipped
    assert sum(isinstance(frame, WsStreamBlock) for frame in emitted) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_home_formatter_table_block():
    """A ``table`` block with headers and rows produces one table stream_block."""
    payload = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}'
    home = HomeFormatter(request_id="req-6", tool_results=[])

    emitted = await collect(home, _stream(("task_agent", payload)))

    blocks = [frame for frame in emitted if isinstance(frame, WsStreamBlock)]
    assert len(blocks) == 1
    assert blocks[0].block_type == "table"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_home_formatter_timeline_block():
    """A ``timeline`` block with checkpoints produces one timeline stream_block."""
    payload = '{"type": "timeline", "checkpoints": [{"id": "c1", "title": "M1", "date": 123}]}'
    home = HomeFormatter(request_id="req-7", tool_results=[])

    emitted = await collect(home, _stream(("task_agent", payload)))

    blocks = [frame for frame in emitted if isinstance(frame, WsStreamBlock)]
    assert len(blocks) == 1
    assert blocks[0].block_type == "timeline"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_home_formatter_frame_order():
    """stream_start is first, stream_end is last."""
    home = HomeFormatter(request_id="req-8", tool_results=[])
    source = _stream(("task_agent", '{"type": "text", "content": "Hi"}'))

    emitted = await collect(home, source)

    assert isinstance(emitted[0], WsStreamStart)
    assert isinstance(emitted[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
|
# ── PopupFormatter ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_popup_formatter_domain_emitted_first():
    """The very first popup frame is popup_domain for the agent's domain."""
    popup = PopupFormatter(request_id="pop-1")
    source = _stream(
        ("task_agent", ""),  # domain signal
        ("task_agent", "Hello"),
        ("task_agent", " there"),
    )

    emitted = await collect(popup, source)

    head = emitted[0]
    assert isinstance(head, WsPopupDomain)
    assert head.domain == "tasks"
    assert head.request_id == "pop-1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_popup_formatter_text_only():
    """Empty tokens produce no text frame; non-empty tokens pass through unchanged."""
    popup = PopupFormatter(request_id="pop-2")
    source = _stream(("checkpoint_agent", ""), ("checkpoint_agent", "Summary"))

    emitted = await collect(popup, source)

    head = emitted[0]
    assert isinstance(head, WsPopupDomain)
    assert head.domain == "checkpoints"
    texts = [frame for frame in emitted if isinstance(frame, WsStreamText)]
    assert len(texts) == 1
    assert texts[0].chunk == "Summary"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_popup_formatter_no_block_frames():
    """PopupFormatter must never emit WsStreamBlock."""
    popup = PopupFormatter(request_id="pop-3")
    source = _stream(
        ("note_agent", ""),
        ("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'),
    )

    emitted = await collect(popup, source)

    assert all(not isinstance(frame, WsStreamBlock) for frame in emitted)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_popup_formatter_end_frame():
    """The popup stream is closed by a stream_end frame."""
    popup = PopupFormatter(request_id="pop-4")
    source = _stream(("project_agent", ""), ("project_agent", "Done"))

    emitted = await collect(popup, source)

    assert isinstance(emitted[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_popup_formatter_unknown_agent_defaults_to_tasks():
    """An agent name without a domain mapping is reported as the "tasks" domain."""
    popup = PopupFormatter(request_id="pop-5")
    source = _stream(("unknown_agent", ""), ("unknown_agent", "hi"))

    emitted = await collect(popup, source)

    assert emitted[0].domain == "tasks"
|
||||||
Reference in New Issue
Block a user