"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
|
from app.schemas import (
|
|
WsFloatingDomain,
|
|
WsStreamBlock,
|
|
WsStreamEnd,
|
|
WsStreamStart,
|
|
WsStreamText,
|
|
)
|
|
|
|
|
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
async def _stream(*pairs: tuple[str, str]):
    """Replay the given ``(agent_name, token)`` pairs, in order, as an async iterator."""
    for agent_name, token in pairs:
        yield agent_name, token
|
|
|
|
|
|
async def collect(formatter, token_stream):
    """Drain ``formatter.format(token_stream)`` and return every emitted frame.

    Args:
        formatter: An object whose ``format`` method accepts ``token_stream``
            and returns an async iterator of frames.
        token_stream: Async iterator of ``(agent_name, token)`` pairs, e.g.
            one produced by ``_stream``.

    Returns:
        A list of the frames, in the order the formatter yielded them.
    """
    return [frame async for frame in formatter.format(token_stream)]
|
|
|
|
|
|
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_text_block():
    """A ``text`` block is streamed as WsStreamText between start and end frames."""
    req_id = "req-1"
    tokens = [
        ("task_agent", '{"type": "text", "content": "Hello world"}'),
    ]
    formatter = HomeFormatter(request_id=req_id, tool_results=[])
    frames = await collect(formatter, _stream(*tokens))

    # Fail with a clear message rather than an IndexError on an empty stream.
    assert frames, "formatter emitted no frames"
    assert isinstance(frames[0], WsStreamStart)
    assert frames[0].request_id == req_id
    # The content may be split across several WsStreamText chunks, so check the
    # concatenated text instead of requiring it inside a single chunk.
    streamed_text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
    assert "Hello world" in streamed_text
    assert isinstance(frames[-1], WsStreamEnd)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_chart_block():
    """A well-formed bar chart payload yields exactly one WsStreamBlock of type chart."""
    payload = (
        '{"type": "chart", "chartType": "bar", '
        '"title": "Tasks", "data": [{"x": 1}], '
        '"config": {"x": {"label": "X", "color": "#fff"}}}'
    )
    home = HomeFormatter(request_id="req-2", tool_results=[])
    emitted = await collect(home, _stream(("task_agent", payload)))

    charts = [frame for frame in emitted if isinstance(frame, WsStreamBlock)]
    assert len(charts) == 1
    (only_block,) = charts
    assert only_block.block_type == "chart"
    assert only_block.data["chartType"] == "bar"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_invalid_chart_skipped():
    """A chart with an unsupported ``chartType`` is dropped instead of emitted as a block."""
    req_id = "req-3"
    # Same shape as a valid chart payload (title, non-empty data, config); only
    # chartType is invalid. The test therefore passes only if the unknown chart
    # type itself is rejected, not because other fields are missing.
    bad_chart = (
        '{"type": "chart", "chartType": "unknown", '
        '"title": "Tasks", "data": [{"x": 1}], '
        '"config": {"x": {"label": "X", "color": "#fff"}}}'
    )
    formatter = HomeFormatter(request_id=req_id, tool_results=[])
    frames = await collect(formatter, _stream(("task_agent", bad_chart)))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 0  # invalid chart skipped
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_entity_ref_resolved():
    """An entity_ref block is resolved against tool_results into a block whose items come from them."""
    matching_results = [{"entity": "task", "id": "t1", "title": "My Task"}]
    home = HomeFormatter(request_id="req-4", tool_results=matching_results)
    emitted = await collect(
        home, _stream(("task_agent", '{"type": "entity_ref", "entity": "task"}'))
    )

    blocks = [frame for frame in emitted if isinstance(frame, WsStreamBlock)]
    assert len(blocks) == 1
    resolved = blocks[0].data
    assert resolved["entity"] == "task"
    assert resolved["items"][0]["id"] == "t1"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_entity_ref_missing_skipped():
    """With no tool results to resolve against, an entity_ref produces no WsStreamBlock."""
    home = HomeFormatter(request_id="req-5", tool_results=[])
    emitted = await collect(
        home, _stream(("task_agent", '{"type": "entity_ref", "entity": "task"}'))
    )

    # Nothing to resolve the reference against, so the block must be dropped.
    assert not any(isinstance(frame, WsStreamBlock) for frame in emitted)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_table_block():
    """A table payload yields exactly one WsStreamBlock of type table."""
    payload = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}'
    home = HomeFormatter(request_id="req-6", tool_results=[])
    emitted = await collect(home, _stream(("task_agent", payload)))

    tables = [frame for frame in emitted if isinstance(frame, WsStreamBlock)]
    assert len(tables) == 1
    assert tables[0].block_type == "table"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_timeline_block():
    """A timeline payload with one checkpoint yields exactly one WsStreamBlock of type timeline."""
    payload = (
        '{"type": "timeline", '
        '"checkpoints": [{"id": "c1", "title": "M1", "date": 123}]}'
    )
    home = HomeFormatter(request_id="req-7", tool_results=[])
    emitted = await collect(home, _stream(("task_agent", payload)))

    timelines = [frame for frame in emitted if isinstance(frame, WsStreamBlock)]
    assert len(timelines) == 1
    assert timelines[0].block_type == "timeline"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_frame_order():
    """The stream is bracketed: WsStreamStart opens it and WsStreamEnd closes it."""
    home = HomeFormatter(request_id="req-8", tool_results=[])
    emitted = await collect(
        home, _stream(("task_agent", '{"type": "text", "content": "Hi"}'))
    )
    opening = emitted[0]
    assert isinstance(opening, WsStreamStart)
    closing = emitted[-1]
    assert isinstance(closing, WsStreamEnd)
|
|
|
|
|
|
# ── FloatingFormatter ────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_floating_formatter_domain_emitted_first():
    """The first frame is a WsFloatingDomain mapping task_agent to the tasks domain."""
    rid = "pop-1"
    popup = FloatingFormatter(request_id=rid)
    token_pairs = (
        ("task_agent", ""),  # empty token: domain signal
        ("task_agent", "Hello"),
        ("task_agent", " there"),
    )
    emitted = await collect(popup, _stream(*token_pairs))

    head = emitted[0]
    assert isinstance(head, WsFloatingDomain)
    assert head.domain == "tasks"
    assert head.request_id == rid
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_floating_formatter_text_only():
    """checkpoint_agent maps to the checkpoints domain and its single token is streamed as text."""
    popup = FloatingFormatter(request_id="pop-2")
    emitted = await collect(
        popup, _stream(("checkpoint_agent", ""), ("checkpoint_agent", "Summary"))
    )

    assert isinstance(emitted[0], WsFloatingDomain)
    assert emitted[0].domain == "checkpoints"
    chunks = [frame.chunk for frame in emitted if isinstance(frame, WsStreamText)]
    assert chunks == ["Summary"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_floating_formatter_no_block_frames():
    """FloatingFormatter never emits WsStreamBlock, even for a token that looks like a chart."""
    popup = FloatingFormatter(request_id="pop-3")
    chart_like = '{"type": "chart", "chartType": "bar", "data": []}'
    emitted = await collect(popup, _stream(("note_agent", ""), ("note_agent", chart_like)))

    block_count = sum(isinstance(frame, WsStreamBlock) for frame in emitted)
    assert block_count == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_floating_formatter_end_frame():
    """The floating stream is terminated by a WsStreamEnd frame."""
    popup = FloatingFormatter(request_id="pop-4")
    pairs = [("project_agent", ""), ("project_agent", "Done")]
    emitted = await collect(popup, _stream(*pairs))
    final = emitted[-1]
    assert isinstance(final, WsStreamEnd)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_floating_formatter_unknown_agent_defaults_to_tasks():
    """An unrecognised agent name falls back to the ``tasks`` domain."""
    req_id = "pop-5"
    formatter = FloatingFormatter(request_id=req_id)
    frames = await collect(
        formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi"))
    )

    assert frames, "formatter emitted no frames"
    # Check the frame type first, as the other FloatingFormatter tests do, so a
    # wrong first frame fails with a clear assertion instead of an AttributeError.
    assert isinstance(frames[0], WsFloatingDomain)
    assert frames[0].domain == "tasks"
    assert frames[0].request_id == req_id
|