83 lines
2.4 KiB
Python
83 lines
2.4 KiB
Python
"""Tests for app.core.output_formatter.StreamFormatter."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
from app.core.output_formatter import StreamFormatter
|
|
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
|
|
|
|
|
async def _stream(*events: tuple[str, object]):
|
|
for event in events:
|
|
yield event
|
|
|
|
|
|
async def _collect(formatter: StreamFormatter, event_stream):
|
|
frames = []
|
|
async for frame in formatter.format(event_stream):
|
|
frames.append(frame)
|
|
return frames
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_formatter_text_stream() -> None:
    """Each token becomes its own WsStreamText frame.

    The text frames are bracketed by WsStreamStart and WsStreamEnd.
    """
    formatter = StreamFormatter(request_id="req-1")
    frames = await _collect(
        formatter,
        _stream(("token", "Hello"), ("token", " world")),
    )

    assert isinstance(frames[0], WsStreamStart)
    # Text frames start right after the start frame, one per token, in order.
    for position, expected_chunk in enumerate(("Hello", " world"), start=1):
        text_frame = frames[position]
        assert isinstance(text_frame, WsStreamText)
        assert text_frame.chunk == expected_chunk
    assert isinstance(frames[-1], WsStreamEnd)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_formatter_floating_domain_first() -> None:
    """A floating_domain event is surfaced as a WsFloatingDomain frame.

    That frame comes ahead of the WsStreamStart/WsStreamText/WsStreamEnd sequence.
    """
    domain_payload = {"type": "node", "id": "n-1", "section": None}
    formatter = StreamFormatter(request_id="req-2")
    frames = await _collect(
        formatter,
        _stream(("floating_domain", domain_payload), ("token", "Summary")),
    )

    floating = frames[0]
    assert isinstance(floating, WsFloatingDomain)
    assert floating.domain.type == "node"
    assert floating.domain.id == "n-1"

    assert isinstance(frames[1], WsStreamStart)

    summary = frames[2]
    assert isinstance(summary, WsStreamText)
    assert summary.chunk == "Summary"

    assert isinstance(frames[-1], WsStreamEnd)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_formatter_ignores_unknown_events() -> None:
    """A ``tool_end`` event adds no text frame.

    Only the single ``token`` event shows up as a WsStreamText chunk.
    """
    frames = await _collect(
        StreamFormatter(request_id="req-3"),
        _stream(("tool_end", {"name": "x"}), ("token", "ok")),
    )

    chunks = [frame.chunk for frame in frames if isinstance(frame, WsStreamText)]
    assert chunks == ["ok"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_formatter_empty_stream_still_brackets() -> None:
    """With no input events, the formatter still emits exactly WsStreamStart then WsStreamEnd."""
    frames = await _collect(StreamFormatter(request_id="req-4"), _stream())

    assert len(frames) == 2
    # The length check above guarantees the unpack below cannot raise.
    opening, closing = frames
    assert isinstance(opening, WsStreamStart)
    assert isinstance(closing, WsStreamEnd)
|