"""Output Formatter — transforms orchestrator token streams into WS frame sequences.

HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
FloatingFormatter: produces floating_domain, stream_start, stream_text, stream_end
"""

from __future__ import annotations

import json
import logging
from collections.abc import AsyncGenerator
from typing import Any

from app.schemas import (
    WsFloatingDomain,
    WsStreamBlock,
    WsStreamEnd,
    WsStreamStart,
    WsStreamText,
)

logger = logging.getLogger(__name__)

# Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}

# Map agent name → floating domain
_AGENT_DOMAIN: dict[str, str] = {
    "task_agent": "tasks",
    "checkpoint_agent": "checkpoints",
    "note_agent": "notes",
    "project_agent": "projects",
}

# Domain used when the agent is not listed in _AGENT_DOMAIN (or the stream is empty).
_DEFAULT_DOMAIN = "tasks"

WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain


class HomeFormatter:
    """Parses a token stream from orchestrate_v3_stream and yields WS frames.

    The LLM is expected to output a newline-delimited sequence of JSON
    objects, each with a ``type`` field:

    - ``text``       → yields WsStreamText with the object's ``content``
                       once the whole object has arrived
    - ``chart``      → validates ``chartType`` / ``data``, yields WsStreamBlock
    - ``entity_ref`` → resolves from tool_results, yields WsStreamBlock with
                       ``{"entity": ..., "items": [...]}``
    - ``table``      → validates ``headers`` / ``rows``, yields WsStreamBlock
    - ``timeline``   → validates ``checkpoints``, yields WsStreamBlock

    Lines that do not start with ``{`` are passed through as plain
    WsStreamText, one chunk per completed line.

    Invalid or unknown blocks are logged and skipped — stream never crashes.
    """

    def __init__(self, request_id: str, tool_results: list[dict]) -> None:
        self.request_id = request_id
        self.tool_results = tool_results

    async def format(
        self,
        token_stream: AsyncGenerator[tuple[str, str], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Yield stream_start, one frame per parsed unit, then stream_end.

        Tokens are appended to a buffer. After each token, every complete
        unit at the front of the buffer is emitted. Only the unconsumed tail
        (a partial line or an incomplete JSON object) is kept for later tokens.
        """
        yield WsStreamStart(request_id=self.request_id)

        buffer = ""
        async for _agent_name, token in token_stream:
            if not token:
                continue
            buffer += token
            # Keep the unparsed tail so later tokens can complete it; never
            # discard it after emitting a frame.
            frames, buffer = self._drain(buffer)
            for frame in frames:
                yield frame

        # The stream is over: whatever is left can no longer be completed.
        if buffer.strip():
            frames, _ = self._drain(buffer, final=True)
            for frame in frames:
                yield frame

        yield WsStreamEnd(request_id=self.request_id)

    async def _flush_complete_objects(
        self, text: str, final: bool = False
    ) -> AsyncGenerator[WsFrame, None]:
        """Try to parse and yield all complete JSON objects / lines from *text*.

        Yields nothing for an incomplete trailing unit unless *final* is True,
        in which case that remaining text is emitted as plain stream_text.
        The unconsumed tail is discarded; use ``_drain`` to keep it.
        """
        frames, _ = self._drain(text, final=final)
        for frame in frames:
            yield frame

    def _drain(self, text: str, final: bool = False) -> tuple[list[WsFrame], str]:
        """Parse every complete unit at the front of *text*.

        Returns ``(frames, remaining)``. ``remaining`` is the text not yet
        consumed (with leading whitespace removed), so the caller can append
        further tokens to it. When *final* is True, an incomplete or
        unparseable tail is emitted as plain stream_text and ``remaining``
        is ``""``.
        """
        frames: list[WsFrame] = []
        remaining = text.lstrip()

        while remaining:
            if not remaining.startswith("{"):
                # Plain text: emit one chunk per completed line.
                newline_idx = remaining.find("\n")
                if newline_idx == -1:
                    if not final:
                        break  # partial line — wait for more tokens
                    frames.append(
                        WsStreamText(request_id=self.request_id, chunk=remaining.rstrip())
                    )
                    remaining = ""
                    break
                # NOTE(review): the line is stripped and the newline dropped, so
                # consecutive plain-text lines arrive without a separator —
                # confirm the client does not rely on line breaks here.
                line = remaining[:newline_idx].strip()
                remaining = remaining[newline_idx + 1:].lstrip()
                if line:
                    frames.append(WsStreamText(request_id=self.request_id, chunk=line))
                continue

            try:
                obj, end_idx = _try_parse_json(remaining)
            except ValueError:
                # Not (yet) valid JSON; mid-stream we wait for more tokens.
                obj, end_idx = None, 0

            if obj is None:
                if final:
                    # Emit as raw text if we can't parse.
                    frames.append(
                        WsStreamText(request_id=self.request_id, chunk=remaining.rstrip())
                    )
                    remaining = ""
                break  # incomplete — need more tokens

            remaining = remaining[end_idx:].lstrip()
            frame = self._dispatch_block(obj, obj.get("type"))
            if frame is not None:
                frames.append(frame)

        return frames, remaining

    def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
        """Validate one decoded JSON object and convert it to a frame.

        Returns ``None`` (after logging a warning, except for empty text
        blocks) when the block is invalid or its type is unknown.
        """
        if block_type == "text":
            content = obj.get("content", "")
            if content:
                return WsStreamText(request_id=self.request_id, chunk=str(content))
            return None

        if block_type == "chart":
            chart_type = obj.get("chartType")
            # isinstance guard: an unhashable value (list/dict) would make the
            # set membership test raise TypeError and abort the stream.
            if not isinstance(chart_type, str) or chart_type not in _VALID_CHART_TYPES:
                logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
                return None
            if not isinstance(obj.get("data"), list):
                logger.warning("HomeFormatter: chart missing data array — skipping")
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="chart",
                data=obj,
            )

        if block_type == "entity_ref":
            entity = obj.get("entity")
            resolved = self._resolve_entity(entity)
            if resolved is None:
                logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="entity_ref",
                data={"entity": entity, "items": resolved},
            )

        if block_type == "table":
            if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
                logger.warning("HomeFormatter: table missing headers/rows — skipping")
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="table",
                data=obj,
            )

        if block_type == "timeline":
            if not isinstance(obj.get("checkpoints"), list):
                logger.warning("HomeFormatter: timeline missing checkpoints — skipping")
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="timeline",
                data=obj,
            )

        logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
        return None

    def _resolve_entity(self, entity: str | None) -> list[dict] | None:
        """Find matching items in tool_results by entity type.

        Returns the tool results whose ``"entity"`` equals *entity*, or
        ``None`` when *entity* is empty or nothing matches.
        """
        if not entity:
            return None
        # Skip non-dict entries rather than crash the stream on them.
        matches = [
            r for r in self.tool_results if isinstance(r, dict) and r.get("entity") == entity
        ]
        return matches if matches else None


class FloatingFormatter:
    """Parses a token stream from orchestrate_v3_stream and yields WS frames.

    Emits floating_domain (from the first agent_name) and stream_start, then
    streams all tokens as plain stream_text — no block parsing for floating
    context — and finishes with stream_end.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id

    async def format(
        self,
        token_stream: AsyncGenerator[tuple[str, str], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Yield floating_domain, stream_start, token chunks, then stream_end."""
        opened = False
        async for agent_name, token in token_stream:
            if not opened:
                for frame in self._opening_frames(agent_name):
                    yield frame
                opened = True
            if token:
                yield WsStreamText(request_id=self.request_id, chunk=token)

        if not opened:
            # Empty stream: still open with the default domain so stream_end
            # never arrives without a preceding stream_start.
            for frame in self._opening_frames(""):
                yield frame

        yield WsStreamEnd(request_id=self.request_id)

    def _opening_frames(self, agent_name: str) -> list[WsFrame]:
        """Build the floating_domain + stream_start frames for *agent_name*."""
        domain = _AGENT_DOMAIN.get(agent_name, _DEFAULT_DOMAIN)
        return [
            WsFloatingDomain(
                request_id=self.request_id,
                domain=domain,  # type: ignore[arg-type]
            ),
            WsStreamStart(request_id=self.request_id),
        ]


# ── helpers ───────────────────────────────────────────────────────────────────


def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
    """Attempt to parse the first complete JSON object from *text*.

    Returns ``(parsed_dict, end_index)`` on success and ``(None, 0)`` when the
    object looks incomplete (an unterminated string, or the decoder ran off
    the end of *text*). Raises ``ValueError`` when the text is not JSON or the
    decoded value is not an object.
    """
    decoder = json.JSONDecoder()
    try:
        obj, end_idx = decoder.raw_decode(text)
        if not isinstance(obj, dict):
            raise ValueError("Expected JSON object")
        return obj, end_idx
    except json.JSONDecodeError as exc:
        # Incomplete JSON — need more tokens
        if "Unterminated" in str(exc) or exc.pos == len(text):
            return None, 0
        raise ValueError(str(exc)) from exc