The supervisor LLM now embeds <type>[id1,id2]</type> entity tags in its response text. The HomeFormatter buffers streamed tokens, detects complete tags even when they span chunk boundaries, and emits a WsStreamBlock carrying the entity type and the specific IDs. This replaces the old approach of emitting a block for every tool_end event, which dumped ALL entities regardless of relevance. Other changes: - adds a NoneType guard on metadata in _run_graph_stream (metadata can be None) - updates the _HOME_SYSTEM prompt with entity-tag instructions - updates all affected tests
239 lines
8.2 KiB
Python
239 lines
8.2 KiB
Python
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.
|
|
|
|
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
|
|
* ``("token", str)`` — supervisor text token
|
|
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
|
|
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
|
|
|
HomeFormatter:
|
|
* Buffers text tokens and parses inline entity tags
|
|
``<type>[id1,id2]</type>`` → emits ``WsStreamBlock`` (entity_ref with IDs)
|
|
* Streams surrounding text → emits ``WsStreamText``
|
|
* Attaches mutations → injects into ``WsStreamEnd``
|
|
|
|
FloatingFormatter:
|
|
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
|
|
* Streams text tokens → emits ``WsStreamText``
|
|
* Attaches mutations → injects into ``WsStreamEnd``
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Any
|
|
|
|
from app.schemas import (
|
|
WsFloatingDomain,
|
|
WsStreamBlock,
|
|
WsStreamEnd,
|
|
WsStreamStart,
|
|
WsStreamText,
|
|
)
|
|
|
|
logger = logging.getLogger(name=__name__)


# Sub-agent tool name → floating-view domain (which is also the plural
# entity type the agent works on).
_AGENT_DOMAIN: dict[str, str] = dict(
    task_agent="tasks",
    timeline_agent="timelines",
    note_agent="notes",
    project_agent="projects",
)
|
|
|
|
# Tag name → plural entity type for the WsStreamBlock data. This map is the
# single source of truth for which inline tag names are recognised.
_TAG_ENTITY: dict[str, str] = {
    "task": "tasks",
    "project": "projects",
    "note": "notes",
    "timeline": "timelines",
}

# Regex for complete inline entity tags: <task>[id1,id2]</task>
#   group(1) — the tag name (a key of _TAG_ENTITY)
#   group(2) — the raw comma-separated ID list (one or more non-"]" chars)
# The alternation is generated from _TAG_ENTITY so the two cannot drift apart;
# a tag accepted here but missing from the map would raise KeyError on lookup.
_ENTITY_TAG_RE = re.compile(
    r"<(" + "|".join(map(re.escape, _TAG_ENTITY)) + r")>\[([^\]]+)\]</\1>"
)
|
|
|
|
# Union of every WS frame type the formatters in this module can yield; used
# as the element type of their ``format()`` async generators. This is a runtime
# expression (not an annotation deferred by ``from __future__ import
# annotations``), so the ``X | Y`` syntax needs Python 3.10+.
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
|
|
|
|
|
|
class HomeFormatter:
    """Consumes a deep-agent event stream and yields WS frames for the Home view.

    The supervisor's response contains inline entity tags like
    ``<project>[abc-123]</project>``. This formatter detects them,
    emits ``WsStreamBlock(block_type="entity_ref")`` with the entity
    type and IDs, and forwards surrounding text as ``WsStreamText``.
    Mutations are attached to ``WsStreamEnd``.

    Tags may be split at any point across token boundaries (for example
    ``"<task>[a"`` + ``"b]</"`` + ``"task>"``). Text that could still grow
    into a tag is held back until it either completes or can no longer
    match; a ``<`` that cannot start a tag (``"x < 5"``) is streamed at once.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._mutations: list[dict] = []
        # Text received but not yet emitted. Between flushes this is either
        # empty or a possible incomplete entity tag.
        self._buffer: str = ""

    @staticmethod
    def _is_partial_tag(fragment: str) -> bool:
        """Return True if *fragment* is a proper prefix of some entity tag.

        Mirrors ``_ENTITY_TAG_RE``: a complete tag is ``<type>[`` followed by
        one or more non-``]`` characters and then ``]</type>``. *fragment*
        is expected to start with ``<`` and to contain no complete tag.
        """
        for tag in _TAG_ENTITY:
            opener = f"<{tag}>["
            if len(fragment) <= len(opener):
                if opener.startswith(fragment):
                    return True
                continue
            if not fragment.startswith(opener):
                continue
            ids_and_tail = fragment[len(opener):]
            close_at = ids_and_tail.find("]")
            if close_at == -1:
                # Still inside the ID list.
                return True
            if close_at == 0:
                # "<type>[]" can never match: the ID list must be non-empty.
                return False
            tail = ids_and_tail[close_at:]
            closer = f"]</{tag}>"
            return len(tail) < len(closer) and closer.startswith(tail)
        return False

    @classmethod
    def _partial_tag_start(cls, text: str) -> int:
        """Return the index of the earliest ``<`` that may begin an unfinished tag.

        Returns ``len(text)`` when no suffix of *text* can still become a
        tag, i.e. all of *text* is safe to emit. The *earliest* candidate is
        used because the closing ``</type>`` also contains ``<``; checking
        only the last ``<`` would release ``"<task>[a]"`` from
        ``"<task>[a]</"`` as plain text.
        """
        start = text.find("<")
        while start != -1:
            if cls._is_partial_tag(text[start:]):
                return start
            start = text.find("<", start + 1)
        return len(text)

    def _flush_buffer(self, force: bool = False):
        """Extract complete entity tags and text from the buffer.

        Yields (frame_type, data) pairs:
            ("text", str)   — plain text to send as WsStreamText
            ("block", dict) — ``{"entity": ..., "ids": [...]}`` entity_ref
                              block to send as WsStreamBlock

        When *force* is True (end of stream), the entire buffer is flushed.
        Otherwise a trailing fragment that may be the start of an entity tag
        split across token boundaries is kept in the buffer. Tags whose ID
        list is empty after trimming (e.g. ``<task>[ , ]</task>``) are
        dropped rather than emitted as blocks without IDs.
        """
        buf = self._buffer
        pieces: list[tuple[str, Any]] = []
        pos = 0

        for m in _ENTITY_TAG_RE.finditer(buf):
            if m.start() > pos:
                pieces.append(("text", buf[pos : m.start()]))
            pos = m.end()
            ids = [i.strip() for i in m.group(2).split(",") if i.strip()]
            if not ids:
                logger.debug("Dropping entity tag without IDs: %r", m.group(0))
                continue
            pieces.append(("block", {"entity": _TAG_ENTITY[m.group(1)], "ids": ids}))

        rest = buf[pos:]
        keep_from = len(rest) if force else self._partial_tag_start(rest)
        if keep_from:
            pieces.append(("text", rest[:keep_from]))

        # Commit the new buffer state before handing pieces out, so it is
        # consistent even if the consumer stops iterating early.
        self._buffer = rest[keep_from:]
        yield from pieces

    def _frames(self, force: bool = False):
        """Flush the buffer (see ``_flush_buffer``) and wrap each piece in a WS frame."""
        for kind, payload in self._flush_buffer(force=force):
            if kind == "text":
                yield WsStreamText(request_id=self.request_id, chunk=payload)
            elif kind == "block":
                yield WsStreamBlock(
                    request_id=self.request_id,
                    block_type="entity_ref",
                    data=payload,
                )

    async def format(
        self,
        event_stream: AsyncGenerator[tuple[str, Any], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Translate *event_stream* into Home-view WS frames.

        Yields ``WsStreamStart`` first, then ``WsStreamText`` /
        ``WsStreamBlock`` frames as token text is parsed, and finally
        ``WsStreamEnd`` carrying the last ``mutations`` payload received.
        """
        yield WsStreamStart(request_id=self.request_id)

        async for event_type, data in event_stream:
            if event_type == "token":
                if data:
                    self._buffer += data
                    for frame in self._frames():
                        yield frame

            elif event_type == "mutations":
                self._mutations = data or []

            # tool_end events are intentionally ignored — the supervisor
            # embeds relevant entity IDs inline via <type>[ids]</type> tags.

        # End of stream: anything still held back can no longer become a tag.
        for frame in self._frames(force=True):
            yield frame

        yield WsStreamEnd(
            request_id=self.request_id,
            mutations=[
                {"action": m["action"], "table": m["table"], "data": m["data"]}
                for m in self._mutations
            ],
        )
|
|
|
|
|
|
class FloatingFormatter:
    """Consumes a deep-agent event stream and yields WS frames for the Floating view.

    Sniffs the first ``tool_end`` event name to derive the domain (e.g.
    ``task_agent`` → ``"tasks"``), then streams text tokens as plain
    ``WsStreamText``. No block parsing for floating context.

    Frames are always emitted in the order ``WsFloatingDomain``,
    ``WsStreamStart``, zero or more ``WsStreamText``, ``WsStreamEnd``. The
    domain falls back to ``"tasks"`` when the sub-agent name is unknown, when
    non-empty text arrives before any ``tool_end``, or when neither occurs.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._mutations: list[dict] = []

    def _open(self, domain: str):
        """Yield the domain announcement followed by ``WsStreamStart``."""
        yield WsFloatingDomain(
            request_id=self.request_id,
            domain=domain,  # type: ignore[arg-type]
        )
        yield WsStreamStart(request_id=self.request_id)

    async def format(
        self,
        event_stream: AsyncGenerator[tuple[str, Any], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Translate *event_stream* into Floating-view WS frames.

        ``WsStreamEnd`` carries the last ``mutations`` payload received.
        """
        domain_sent = False

        async for event_type, data in event_stream:
            if event_type == "tool_end" and not domain_sent:
                # Sniff domain from the first sub-agent that completes.
                # Tolerate a None payload instead of crashing on .get().
                name = (data or {}).get("name", "")
                domain = _AGENT_DOMAIN.get(name)
                if domain is None:
                    logger.debug(
                        "Unknown sub-agent %r; defaulting floating domain to 'tasks'",
                        name,
                    )
                    domain = "tasks"
                for frame in self._open(domain):
                    yield frame
                domain_sent = True

            elif event_type == "token":
                # Empty tokens carry no text; skip them before the domain
                # check so they cannot lock in the default domain ahead of
                # the first tool_end.
                if not data:
                    continue
                if not domain_sent:
                    # Text arrived before any tool_end — default domain
                    for frame in self._open("tasks"):
                        yield frame
                    domain_sent = True
                yield WsStreamText(request_id=self.request_id, chunk=data)

            elif event_type == "mutations":
                self._mutations = data or []

        # If no events triggered domain_sent (edge case), still emit structure
        if not domain_sent:
            for frame in self._open("tasks"):
                yield frame

        yield WsStreamEnd(
            request_id=self.request_id,
            mutations=[
                {"action": m["action"], "table": m["table"], "data": m["data"]}
                for m in self._mutations
            ],
        )
|