"""Output Formatter — transforms deep-agent event streams into WS frame sequences. Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``: * ``("token", str)`` — supervisor text token * ``("tool_end", dict)`` — sub-agent finished: ``{name, result}`` * ``("mutations", list)`` — collected CRUD mutations for ``stream_end`` HomeFormatter: * Buffers text tokens and parses inline entity tags ``[id1,id2]`` → emits ``WsStreamBlock`` (entity_ref with IDs) * Streams surrounding text → emits ``WsStreamText`` * Attaches mutations → injects into ``WsStreamEnd`` FloatingFormatter: * Sniffs first ``tool_end`` name → emits ``WsFloatingDomain`` * Streams text tokens → emits ``WsStreamText`` * Attaches mutations → injects into ``WsStreamEnd`` """ from __future__ import annotations import logging import re from collections.abc import AsyncGenerator from typing import Any from app.schemas import ( WsFloatingDomain, WsStreamBlock, WsStreamEnd, WsStreamStart, WsStreamText, ) logger = logging.getLogger(__name__) # Map sub-agent tool name → floating domain / entity type _AGENT_DOMAIN: dict[str, str] = { "task_agent": "tasks", "timeline_agent": "timelines", "note_agent": "notes", "project_agent": "projects", } # Regex for complete inline entity tags: [id1,id2] _ENTITY_TAG_RE = re.compile( r"<(task|project|note|timeline)>\[([^\]]+)\]" ) # Tag name → plural entity type for the WsStreamBlock data _TAG_ENTITY: dict[str, str] = { "task": "tasks", "project": "projects", "note": "notes", "timeline": "timelines", } WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain class HomeFormatter: """Consumes a deep-agent event stream and yields WS frames for the Home view. The supervisor's response contains inline entity tags like ``[abc-123]``. This formatter detects them, emits ``WsStreamBlock(block_type="entity_ref")`` with the entity type and IDs, and forwards surrounding text as ``WsStreamText``. Mutations are attached to ``WsStreamEnd``. """ def __init__(self, request_id: str) -> None: self.request_id = request_id self._mutations: list[dict] = [] self._buffer: str = "" def _flush_buffer(self, force: bool = False): """Extract complete entity tags and text from the buffer. Yields (frame_type, data) pairs: ("text", str) — plain text to send as WsStreamText ("block", dict) — entity_ref block to send as WsStreamBlock When *force* is True (end of stream), the entire buffer is flushed. Otherwise, text after the last unmatched ``<`` is held back in case it is the start of an entity tag arriving across token boundaries. """ buf = self._buffer while True: m = _ENTITY_TAG_RE.search(buf) if not m: break # Text before the tag before = buf[: m.start()] if before: yield ("text", before) # The entity tag itself → a block tag_type = m.group(1) raw_ids = m.group(2) ids = [i.strip() for i in raw_ids.split(",") if i.strip()] yield ( "block", { "entity": _TAG_ENTITY[tag_type], "ids": ids, }, ) buf = buf[m.end() :] if force: # End of stream — flush everything that remains if buf: yield ("text", buf) self._buffer = "" else: # Keep a potential partial tag (text after last '<') in the buffer last_lt = buf.rfind("<") if last_lt != -1: safe = buf[:last_lt] if safe: yield ("text", safe) self._buffer = buf[last_lt:] else: if buf: yield ("text", buf) self._buffer = "" async def format( self, event_stream: AsyncGenerator[tuple[str, Any], None], ) -> AsyncGenerator[WsFrame, None]: yield WsStreamStart(request_id=self.request_id) async for event_type, data in event_stream: if event_type == "token": if data: self._buffer += data for ftype, fdata in self._flush_buffer(): if ftype == "text": yield WsStreamText( request_id=self.request_id, chunk=fdata ) elif ftype == "block": yield WsStreamBlock( request_id=self.request_id, block_type="entity_ref", data=fdata, ) elif event_type == "mutations": self._mutations = data or [] # tool_end events are intentionally ignored — the supervisor # embeds relevant entity IDs inline via [ids] tags. # Flush any remaining buffer content for ftype, fdata in self._flush_buffer(force=True): if ftype == "text": yield WsStreamText(request_id=self.request_id, chunk=fdata) elif ftype == "block": yield WsStreamBlock( request_id=self.request_id, block_type="entity_ref", data=fdata, ) yield WsStreamEnd( request_id=self.request_id, mutations=[ {"action": m["action"], "table": m["table"], "data": m["data"]} for m in self._mutations ], ) class FloatingFormatter: """Consumes a deep-agent event stream and yields WS frames for the Floating view. Sniffs the first ``tool_end`` event name to derive the domain (e.g. ``task_agent`` → ``"tasks"``), then streams text tokens as plain ``WsStreamText``. No block parsing for floating context. """ def __init__(self, request_id: str) -> None: self.request_id = request_id self._mutations: list[dict] = [] async def format( self, event_stream: AsyncGenerator[tuple[str, Any], None], ) -> AsyncGenerator[WsFrame, None]: domain_sent = False async for event_type, data in event_stream: if event_type == "tool_end" and not domain_sent: # Sniff domain from the first sub-agent that completes name = data.get("name", "") domain = _AGENT_DOMAIN.get(name, "tasks") yield WsFloatingDomain( request_id=self.request_id, domain=domain, # type: ignore[arg-type] ) yield WsStreamStart(request_id=self.request_id) domain_sent = True elif event_type == "token": if not domain_sent: # First token arrived before any tool_end — default domain yield WsFloatingDomain( request_id=self.request_id, domain="tasks", # type: ignore[arg-type] ) yield WsStreamStart(request_id=self.request_id) domain_sent = True if data: yield WsStreamText(request_id=self.request_id, chunk=data) elif event_type == "mutations": self._mutations = data or [] # If no events triggered domain_sent (edge case), still emit structure if not domain_sent: yield WsFloatingDomain( request_id=self.request_id, domain="tasks", # type: ignore[arg-type] ) yield WsStreamStart(request_id=self.request_id) yield WsStreamEnd( request_id=self.request_id, mutations=[ {"action": m["action"], "table": m["table"], "data": m["data"]} for m in self._mutations ], )