44 lines
1.3 KiB
Python
44 lines
1.3 KiB
Python
"""Output formatter for deep-agent stream events."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Any
|
|
|
|
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
|
|
|
# Every websocket frame model that `StreamFormatter.format` can yield.
# NOTE: this union is evaluated eagerly at import time; the
# `from __future__ import annotations` at the top only defers annotations,
# so this line relies on PEP 604 `X | Y` unions of classes (Python 3.10+).
WsFrame = (
    WsStreamStart
    | WsStreamText
    | WsStreamEnd
    | WsFloatingDomain
)
|
|
|
|
|
|
class StreamFormatter:
    """Convert `(event_type, data)` stream events into websocket frame models.

    An instance is bound to one ``request_id``; every frame it emits is
    stamped with that id.
    """

    def __init__(self, request_id: str) -> None:
        """Store ``request_id`` so it can be stamped onto every emitted frame."""
        self.request_id = request_id

    async def format(
        self,
        event_stream: AsyncGenerator[tuple[str, Any], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Translate ``event_stream`` into websocket frames.

        * ``"floating_domain"`` events are forwarded immediately as
          ``WsFloatingDomain`` frames (``domain=str(data)``), even before
          ``WsStreamStart`` has been sent.
        * The first ``"token"`` event emits ``WsStreamStart``. After that,
          every token whose ``str(data or "")`` is non-empty emits a
          ``WsStreamText`` chunk. ``None`` (or any other falsy ``data``)
          produces no chunk.
        * All other event types are ignored.
        * Once the input is exhausted, ``WsStreamStart`` is emitted if no
          token ever arrived, so the client always sees a start/end pair.
          ``WsStreamEnd`` then closes the stream.

        The upstream ``event_stream`` is explicitly ``aclose()``d when this
        generator finishes, raises, or is closed early by its consumer.
        Otherwise it would stay suspended until garbage collection or the
        event loop's async-generator finalizer reaches it.

        Args:
            event_stream: Async generator of ``(event_type, data)`` pairs.

        Yields:
            Websocket frame models carrying this formatter's ``request_id``.
        """
        started = False

        try:
            async for event_type, data in event_stream:
                if event_type == "floating_domain":
                    yield WsFloatingDomain(
                        request_id=self.request_id,
                        domain=str(data),
                    )
                    continue

                if event_type != "token":
                    continue

                if not started:
                    yield WsStreamStart(request_id=self.request_id)
                    started = True

                text = str(data or "")
                if text:
                    yield WsStreamText(request_id=self.request_id, chunk=text)
        finally:
            # `async for` alone does not close the source generator when this
            # generator is abandoned or a frame constructor raises. Closing it
            # here runs its cleanup now instead of at GC/finalizer time.
            # `aclose()` on an already-exhausted generator is a no-op.
            # `getattr` keeps plain async iterators (without `aclose`) working.
            aclose = getattr(event_stream, "aclose", None)
            if aclose is not None:
                await aclose()

        # Guarantee a start frame even when no token arrived, so every
        # WsStreamEnd is preceded by a WsStreamStart.
        if not started:
            yield WsStreamStart(request_id=self.request_id)
        yield WsStreamEnd(request_id=self.request_id)
|