142 lines
4.9 KiB
Python
142 lines
4.9 KiB
Python
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.

Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:

* ``("token", str)`` — supervisor text token
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``

HomeFormatter:

* Streams text tokens as-is → emits ``WsStreamText``
  (text may contain inline ``<type>[id,...]</type>`` entity tags
  for the frontend to parse and render as interactive components)
* Attaches mutations → injects into ``WsStreamEnd``

FloatingFormatter:

* Sniffs the first ``tool_end`` name → emits ``WsFloatingDomain``
  (falls back to ``"tasks"`` for unmapped names, or when text or the end of
  the stream arrives before any ``tool_end``)
* Streams text tokens → emits ``WsStreamText``
* Attaches mutations → injects into ``WsStreamEnd``
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Any
|
|
|
|
from app.schemas import (
|
|
WsFloatingDomain,
|
|
WsStreamEnd,
|
|
WsStreamStart,
|
|
WsStreamText,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# Sub-agent tool name -> domain (entity type) announced to the Floating view
# through ``WsFloatingDomain``.
_AGENT_DOMAIN: dict[str, str] = dict(
    task_agent="tasks",
    timeline_agent="timelines",
    note_agent="notes",
    project_agent="projects",
)

# Every frame type the formatters in this module can yield.
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
|
|
|
|
|
class HomeFormatter:
    """Consume a deep-agent event stream and yield WS frames for the Home view.

    Frame order is always ``WsStreamStart``, zero or more ``WsStreamText``,
    then exactly one ``WsStreamEnd``.

    Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
    embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
    is responsible for parsing those and rendering interactive components.
    Mutations are attached to ``WsStreamEnd``. Event types other than
    ``"token"`` and ``"mutations"`` (e.g. ``"tool_end"``) are ignored.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        # Mutations from the latest "mutations" event of the stream currently
        # (or most recently) being formatted.
        self._mutations: list[dict] = []

    async def format(
        self,
        event_stream: AsyncGenerator[tuple[str, Any], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Translate *event_stream* into Home-view WS frames.

        Args:
            event_stream: async iterable of ``(event_type, data)`` tuples.

        Yields:
            ``WsStreamStart``; one ``WsStreamText`` per non-empty ``"token"``
            payload; finally a ``WsStreamEnd`` carrying the mutations from the
            last ``"mutations"`` event (an empty list if there was none).

        Raises:
            KeyError: if a mutation lacks ``action``, ``table`` or ``data``.
        """
        # Reset per call so a reused formatter never re-sends the mutations
        # of a previous stream that happened to contain a "mutations" event.
        self._mutations = []

        yield WsStreamStart(request_id=self.request_id)

        async for event_type, data in event_stream:
            if event_type == "token":
                if data:
                    yield WsStreamText(request_id=self.request_id, chunk=data)
            elif event_type == "mutations":
                # Snapshot so later changes to the producer's list can't leak in.
                self._mutations = list(data or [])

        yield WsStreamEnd(
            request_id=self.request_id,
            mutations=self._project_mutations(self._mutations),
        )

    @staticmethod
    def _project_mutations(mutations: list[dict]) -> list[dict]:
        """Keep only the ``action``/``table``/``data`` keys of each mutation."""
        return [
            {"action": m["action"], "table": m["table"], "data": m["data"]}
            for m in mutations
        ]
|
|
|
|
|
|
class FloatingFormatter:
    """Consume a deep-agent event stream and yield WS frames for the Floating view.

    Frame order is always ``WsFloatingDomain``, ``WsStreamStart``, zero or
    more ``WsStreamText``, then exactly one ``WsStreamEnd``.

    The domain is sniffed from the first ``tool_end`` event name (e.g.
    ``task_agent`` → ``"tasks"``). If a token, or the end of the stream,
    arrives before any ``tool_end``, or the agent name is not mapped, the
    domain falls back to ``"tasks"``. Text tokens are streamed as plain
    ``WsStreamText``. No block parsing for floating context.
    """

    # Domain used when no mapped sub-agent has reported before text/end.
    _DEFAULT_DOMAIN = "tasks"

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        # Mutations from the latest "mutations" event of the stream currently
        # (or most recently) being formatted.
        self._mutations: list[dict] = []

    async def format(
        self,
        event_stream: AsyncGenerator[tuple[str, Any], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Translate *event_stream* into Floating-view WS frames.

        Args:
            event_stream: async iterable of ``(event_type, data)`` tuples.

        Yields:
            ``WsFloatingDomain`` + ``WsStreamStart`` (once, on the first
            ``tool_end`` or ``token``, or at stream end if neither came);
            one ``WsStreamText`` per non-empty ``"token"`` payload; finally a
            ``WsStreamEnd`` carrying the mutations from the last
            ``"mutations"`` event (an empty list if there was none).

        Raises:
            KeyError: if a mutation lacks ``action``, ``table`` or ``data``.
        """
        # Reset per call so a reused formatter never re-sends the mutations
        # of a previous stream.
        self._mutations = []
        domain_sent = False

        async for event_type, data in event_stream:
            if event_type == "tool_end" and not domain_sent:
                # Sniff domain from the first sub-agent that completes.
                for frame in self._opening_frames(self._domain_for(data)):
                    yield frame
                domain_sent = True

            elif event_type == "token":
                if not domain_sent:
                    # First token arrived before any tool_end — default domain.
                    for frame in self._opening_frames(self._DEFAULT_DOMAIN):
                        yield frame
                    domain_sent = True
                if data:
                    yield WsStreamText(request_id=self.request_id, chunk=data)

            elif event_type == "mutations":
                # Snapshot so later changes to the producer's list can't leak in.
                self._mutations = list(data or [])

        # Empty stream (or only mutations): the client still expects the
        # domain + start frames before stream end.
        if not domain_sent:
            for frame in self._opening_frames(self._DEFAULT_DOMAIN):
                yield frame

        yield WsStreamEnd(
            request_id=self.request_id,
            mutations=[
                {"action": m["action"], "table": m["table"], "data": m["data"]}
                for m in self._mutations
            ],
        )

    def _domain_for(self, tool_end: Any) -> str:
        """Map a ``tool_end`` payload's ``name`` to a floating domain."""
        try:
            name = tool_end.get("name", "")
        except AttributeError:
            # Payload without a mapping interface: treat as unnamed.
            name = ""
        domain = _AGENT_DOMAIN.get(name)
        if domain is None:
            logger.warning(
                "Unmapped sub-agent %r for request %s; defaulting domain to %r",
                name,
                self.request_id,
                self._DEFAULT_DOMAIN,
            )
            return self._DEFAULT_DOMAIN
        return domain

    def _opening_frames(
        self, domain: str
    ) -> tuple[WsFloatingDomain, WsStreamStart]:
        """Build the domain announcement followed by the stream-start frame."""
        return (
            WsFloatingDomain(
                request_id=self.request_id,
                domain=domain,  # type: ignore[arg-type]
            ),
            WsStreamStart(request_id=self.request_id),
        )
|