HomeFormatter parses JSON block stream from orchestrator tokens and emits stream_start / stream_text / stream_block / stream_end frames. PopupFormatter emits popup_domain then plain stream_text. All 13 unit tests pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
245 lines
9.0 KiB
Python
245 lines
9.0 KiB
Python
"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
|
|
|
|
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
|
|
PopupFormatter: produces popup_domain, stream_start, stream_text, stream_end
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Any
|
|
|
|
from app.schemas import (
|
|
WsPopupDomain,
|
|
WsStreamBlock,
|
|
WsStreamEnd,
|
|
WsStreamStart,
|
|
WsStreamText,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# Valid chart types (matching shadcn/ui Recharts wrappers in Electron).
# A frozenset because this is a read-only lookup table: nothing may mutate it.
_VALID_CHART_TYPES: frozenset[str] = frozenset(
    {"area", "bar", "line", "pie", "radar", "radial"}
)

# Map agent name → popup domain.
_AGENT_DOMAIN: dict[str, str] = {
    "task_agent": "tasks",
    "checkpoint_agent": "checkpoints",
    "note_agent": "notes",
    "project_agent": "projects",
}
|
|
|
|
# Union of every WebSocket frame model the formatters in this module yield.
WsFrame = (
    WsStreamStart
    | WsStreamText
    | WsStreamBlock
    | WsStreamEnd
    | WsPopupDomain
)
|
|
|
|
|
|
class HomeFormatter:
    """Parses a token stream from orchestrate_v3_stream and yields WS frames.

    The LLM is expected to output a newline-delimited sequence of JSON objects,
    each with a ``type`` field:

    - ``text``       → WsStreamText carrying the object's ``content``
    - ``chart``      → validates ``chartType`` / ``data``, yields WsStreamBlock
    - ``entity_ref`` → resolves from tool_results, yields WsStreamBlock
    - ``table``      → validates ``headers`` / ``rows``, yields WsStreamBlock
    - ``timeline``   → validates ``checkpoints``, yields WsStreamBlock

    Every object is emitted once it has been fully received. Lines that do not
    start with ``{`` are emitted as plain stream_text once their newline
    arrives (or at end of stream).

    Blocks that fail validation or have an unknown type are logged and
    skipped. Text starting with ``{`` that never parses as JSON is held back.
    At end of stream it, and everything after it, is emitted as plain
    stream_text.
    """

    def __init__(self, request_id: str, tool_results: list[dict]) -> None:
        self.request_id = request_id
        # Items searched by ``entity_ref`` blocks (matched on their "entity" key).
        self.tool_results = tool_results

    async def format(
        self,
        token_stream: AsyncGenerator[tuple[str, str], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Yield stream_start, then the content frames, then stream_end.

        Tokens are appended to a buffer. After each token, every complete item
        at the front of the buffer is emitted and only the unconsumed tail is
        kept. So several objects arriving in one token, or an object followed
        by the beginning of the next, are never dropped.
        """
        yield WsStreamStart(request_id=self.request_id)

        buffer = ""
        async for _agent_name, token in token_stream:
            if not token:
                continue
            buffer += token
            frames, buffer = self._consume(buffer)
            for frame in frames:
                yield frame

        # End of stream: nothing more will arrive, so flush in "final" mode.
        if buffer.strip():
            frames, _ = self._consume(buffer, final=True)
            for frame in frames:
                yield frame

        yield WsStreamEnd(request_id=self.request_id)

    async def _flush_complete_objects(
        self, text: str, final: bool = False
    ) -> AsyncGenerator[WsFrame, None]:
        """Try to parse and yield all complete JSON objects from *text*.

        Yields nothing for incomplete trailing content unless *final* is True,
        in which case remaining text is emitted as plain stream_text. The
        unconsumed remainder is discarded; use ``_consume`` to keep it.
        """
        frames, _rest = self._consume(text, final=final)
        for frame in frames:
            yield frame

    def _consume(self, text: str, final: bool = False) -> tuple[list[WsFrame], str]:
        """Parse every complete item at the front of *text*.

        Returns ``(frames, rest)``: the frames produced and the unconsumed tail.
        The tail is only left-stripped, so whitespace at the end of an
        unfinished text line survives until the next token is appended. When
        *final* is True nothing is held back. An unterminated line or an
        unparseable ``{...`` remainder is emitted as plain stream_text, and
        ``rest`` is ``""``.
        """
        frames: list[WsFrame] = []
        remaining = text.lstrip()

        while remaining:
            if not remaining.startswith("{"):
                # Plain (non-JSON) text is emitted one line at a time.
                line, newline, after = remaining.partition("\n")
                if not newline:
                    if not final:
                        break  # line not finished yet — wait for more tokens
                    after = ""
                line = line.strip()
                if line:
                    frames.append(WsStreamText(request_id=self.request_id, chunk=line))
                remaining = after.lstrip()
                continue

            try:
                obj, end_idx = _try_parse_json(remaining)
            except ValueError:
                obj, end_idx = None, 0  # not valid JSON (yet)

            if obj is None:
                # Incomplete or malformed object: wait for more tokens, or at end
                # of stream surface the raw text instead of losing it.
                # NOTE(review): the object is re-parsed from its start on every
                # token until it completes.
                if final:
                    frames.append(
                        WsStreamText(request_id=self.request_id, chunk=remaining.strip())
                    )
                    remaining = ""
                break

            remaining = remaining[end_idx:].lstrip()
            frame = self._dispatch_block(obj, obj.get("type"))
            if frame is not None:
                frames.append(frame)

        return frames, remaining

    def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
        """Validate a parsed JSON object and convert it into a frame.

        Returns ``None`` when the block is empty, invalid, or of an unknown
        type. Invalid and unknown blocks are logged as warnings.
        """
        if block_type == "text":
            content = obj.get("content", "")
            if content:
                return WsStreamText(request_id=self.request_id, chunk=str(content))
            return None

        if block_type == "chart":
            chart_type = obj.get("chartType")
            # isinstance guard: an unhashable chartType (e.g. a JSON list) would
            # otherwise raise TypeError on the set lookup and crash the stream.
            if not isinstance(chart_type, str) or chart_type not in _VALID_CHART_TYPES:
                logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
                return None
            if not isinstance(obj.get("data"), list):
                logger.warning("HomeFormatter: chart missing data array — skipping")
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="chart",
                data=obj,
            )

        if block_type == "entity_ref":
            entity = obj.get("entity")
            resolved = self._resolve_entity(entity)
            if resolved is None:
                logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="entity_ref",
                data={"entity": entity, "items": resolved},
            )

        if block_type == "table":
            if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
                logger.warning("HomeFormatter: table missing headers/rows — skipping")
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="table",
                data=obj,
            )

        if block_type == "timeline":
            if not isinstance(obj.get("checkpoints"), list):
                logger.warning("HomeFormatter: timeline missing checkpoints — skipping")
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="timeline",
                data=obj,
            )

        logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
        return None

    def _resolve_entity(self, entity: str | None) -> list[dict] | None:
        """Return the tool_results items whose "entity" equals *entity*, or None if none match."""
        if not entity:
            return None
        # Non-dict items are ignored rather than crashing the stream.
        matches = [
            r for r in self.tool_results
            if isinstance(r, dict) and r.get("entity") == entity
        ]
        return matches or None
|
|
|
|
|
|
class PopupFormatter:
    """Parses a token stream from orchestrate_v3_stream and yields WS frames.

    Emits popup_domain (from the first token's agent_name) and stream_start up
    front. It then streams every non-empty token as plain stream_text, with no
    block parsing for popup context, and finishes with stream_end.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id

    async def format(
        self,
        token_stream: AsyncGenerator[tuple[str, str], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Yield popup_domain, stream_start, stream_text*, stream_end.

        The opening popup_domain/stream_start pair is always emitted. On an
        empty *token_stream* it is sent just before stream_end, using the same
        "tasks" fallback applied to unmapped agents. A client therefore never
        receives a stream_end without a preceding stream_start.
        """
        opened = False

        async for agent_name, token in token_stream:
            if not opened:
                # The first agent seen decides which popup domain is announced.
                yield self._domain_frame(agent_name)
                yield WsStreamStart(request_id=self.request_id)
                opened = True

            if token:
                yield WsStreamText(request_id=self.request_id, chunk=token)

        if not opened:
            # NOTE(review): empty stream — assumes the "tasks" fallback is an
            # acceptable domain to announce here; confirm with the client.
            yield self._domain_frame(None)
            yield WsStreamStart(request_id=self.request_id)

        yield WsStreamEnd(request_id=self.request_id)

    def _domain_frame(self, agent_name: str | None) -> WsPopupDomain:
        """Build the popup_domain frame for *agent_name* ("tasks" when unmapped or None)."""
        domain = _AGENT_DOMAIN.get(agent_name, "tasks")  # type: ignore[arg-type]
        return WsPopupDomain(
            request_id=self.request_id,
            domain=domain,  # type: ignore[arg-type]
        )
|
|
|
|
|
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
# Shared decoder. ``raw_decode`` keeps no per-call state, so one instance can
# be reused (the json module itself shares a default decoder for json.loads).
# ``strict=False`` accepts raw control characters such as literal newlines or
# tabs inside JSON strings. LLM output often contains them, and strict mode
# rejects the whole object.
_JSON_DECODER = json.JSONDecoder(strict=False)


def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
    """Attempt to parse the first complete JSON object from *text*.

    Returns ``(parsed_dict, end_index)`` on success, where *end_index* is the
    offset just past the object. Returns ``(None, 0)`` when the object looks
    truncated: an unterminated string, or the decoder ran off the end of
    *text*.

    Raises:
        ValueError: if *text* is not valid JSON for any other reason, or if
            the first JSON value is not an object.
    """
    try:
        obj, end_idx = _JSON_DECODER.raw_decode(text)
    except json.JSONDecodeError as exc:
        # Incomplete JSON — the caller should wait for more tokens.
        if exc.msg.startswith("Unterminated") or exc.pos == len(text):
            return None, 0
        raise ValueError(str(exc)) from exc
    if not isinstance(obj, dict):
        raise ValueError("Expected JSON object")
    return obj, end_idx
|