"""Output Formatter — transforms deep-agent event streams into WS frame sequences. Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``: * ``("token", str)`` — supervisor text token * ``("tool_end", dict)`` — sub-agent finished: ``{name, result}`` * ``("mutations", list)`` — collected CRUD mutations for ``stream_end`` HomeFormatter: * Sniffs ``tool_end`` events → emits ``WsStreamBlock`` (entity_ref with raw data) * Streams text tokens → emits ``WsStreamText`` * Attaches mutations → injects into ``WsStreamEnd`` FloatingFormatter: * Sniffs first ``tool_end`` name → emits ``WsFloatingDomain`` * Streams text tokens → emits ``WsStreamText`` * Attaches mutations → injects into ``WsStreamEnd`` """ from __future__ import annotations import logging from collections.abc import AsyncGenerator from typing import Any from app.schemas import ( WsFloatingDomain, WsStreamBlock, WsStreamEnd, WsStreamStart, WsStreamText, ) logger = logging.getLogger(__name__) # Map sub-agent tool name → floating domain / entity type _AGENT_DOMAIN: dict[str, str] = { "task_agent": "tasks", "timeline_agent": "timelines", "note_agent": "notes", "project_agent": "projects", } WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain class HomeFormatter: """Consumes a deep-agent event stream and yields WS frames for the Home view. ``tool_end`` events from sub-agents are emitted as ``WsStreamBlock`` (entity_ref) so the client can render structured data. Text tokens are forwarded as ``WsStreamText``. Mutations are attached to ``WsStreamEnd``. """ def __init__(self, request_id: str) -> None: self.request_id = request_id self._mutations: list[dict] = [] async def format( self, event_stream: AsyncGenerator[tuple[str, Any], None], ) -> AsyncGenerator[WsFrame, None]: yield WsStreamStart(request_id=self.request_id) async for event_type, data in event_stream: if event_type == "token": if data: yield WsStreamText(request_id=self.request_id, chunk=data) elif event_type == "tool_end": # Sub-agent finished — emit its result as an entity_ref block name = data.get("name", "") entity = _AGENT_DOMAIN.get(name) if entity: yield WsStreamBlock( request_id=self.request_id, block_type="entity_ref", data={"entity": entity, "result": data.get("result", "")}, ) elif event_type == "mutations": self._mutations = data or [] yield WsStreamEnd( request_id=self.request_id, mutations=[ {"action": m["action"], "table": m["table"], "data": m["data"]} for m in self._mutations ], ) class FloatingFormatter: """Consumes a deep-agent event stream and yields WS frames for the Floating view. Sniffs the first ``tool_end`` event name to derive the domain (e.g. ``task_agent`` → ``"tasks"``), then streams text tokens as plain ``WsStreamText``. No block parsing for floating context. """ def __init__(self, request_id: str) -> None: self.request_id = request_id self._mutations: list[dict] = [] async def format( self, event_stream: AsyncGenerator[tuple[str, Any], None], ) -> AsyncGenerator[WsFrame, None]: domain_sent = False async for event_type, data in event_stream: if event_type == "tool_end" and not domain_sent: # Sniff domain from the first sub-agent that completes name = data.get("name", "") domain = _AGENT_DOMAIN.get(name, "tasks") yield WsFloatingDomain( request_id=self.request_id, domain=domain, # type: ignore[arg-type] ) yield WsStreamStart(request_id=self.request_id) domain_sent = True elif event_type == "token": if not domain_sent: # First token arrived before any tool_end — default domain yield WsFloatingDomain( request_id=self.request_id, domain="tasks", # type: ignore[arg-type] ) yield WsStreamStart(request_id=self.request_id) domain_sent = True if data: yield WsStreamText(request_id=self.request_id, chunk=data) elif event_type == "mutations": self._mutations = data or [] # If no events triggered domain_sent (edge case), still emit structure if not domain_sent: yield WsFloatingDomain( request_id=self.request_id, domain="tasks", # type: ignore[arg-type] ) yield WsStreamStart(request_id=self.request_id) yield WsStreamEnd( request_id=self.request_id, mutations=[ {"action": m["action"], "table": m["table"], "data": m["data"]} for m in self._mutations ], )