48 lines
1.4 KiB
Python
48 lines
1.4 KiB
Python
"""Output formatter for deep-agent stream events."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Any
|
|
|
|
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
|
|
|
# Union of the websocket frame models that `StreamFormatter.format` yields.
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
|
|
|
|
|
class StreamFormatter:
    """Translate `(event_type, data)` stream events into websocket frames.

    Each frame produced by an instance is tagged with the `request_id`
    supplied at construction time.
    """

    def __init__(self, request_id: str) -> None:
        """Store the request id that is attached to every emitted frame."""
        self.request_id = request_id

    async def format(
        self,
        event_stream: AsyncGenerator[tuple[str, Any], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Yield websocket frames for the events read from `event_stream`.

        Handling per event type:

        * ``"floating_domain"`` -- a `WsFloatingDomain` frame is yielded when
          the payload is a `dict`; any other payload is dropped.  These
          frames are independent of the start/end bracketing and may
          therefore appear before `WsStreamStart`.
        * ``"token"`` -- the first token event yields `WsStreamStart`; the
          payload is then stringified and, if non-empty, yielded as a
          `WsStreamText` chunk.  A falsy payload still opens the stream but
          produces no text frame.
        * anything else -- ignored.

        Once `event_stream` is exhausted, `WsStreamStart` is yielded if no
        token event opened the stream, followed by `WsStreamEnd`.
        """
        stream_open = False

        async for kind, payload in event_stream:
            if kind == "floating_domain":
                # Only dict payloads are forwarded; the rest are skipped.
                if isinstance(payload, dict):
                    yield WsFloatingDomain(
                        request_id=self.request_id,
                        domain=payload,
                    )
            elif kind == "token":
                if not stream_open:
                    # The first token event opens the stream, even when the
                    # payload turns out to be empty.
                    yield WsStreamStart(request_id=self.request_id)
                    stream_open = True

                # Falsy payloads collapse to "" and emit no text frame.
                if chunk := str(payload or ""):
                    yield WsStreamText(request_id=self.request_id, chunk=chunk)

        # A stream without token events still gets a start frame so that the
        # end frame below is always preceded by one.
        if not stream_open:
            yield WsStreamStart(request_id=self.request_id)
        yield WsStreamEnd(request_id=self.request_id)
|