"""Tests for app.core.output_formatter — HomeFormatter and PopupFormatter.""" from __future__ import annotations import pytest from app.core.output_formatter import HomeFormatter, PopupFormatter from app.schemas import ( WsPopupDomain, WsStreamBlock, WsStreamEnd, WsStreamStart, WsStreamText, ) # ── helpers ─────────────────────────────────────────────────────────────────── async def _stream(*pairs: tuple[str, str]): """Async generator that yields (agent_name, token) pairs.""" for pair in pairs: yield pair async def collect(formatter, token_stream): frames = [] async for frame in formatter.format(token_stream): frames.append(frame) return frames # ── HomeFormatter ───────────────────────────────────────────────────────────── @pytest.mark.asyncio async def test_home_formatter_text_block(): req_id = "req-1" tokens = [ ("task_agent", '{"type": "text", "content": "Hello world"}'), ] formatter = HomeFormatter(request_id=req_id, tool_results=[]) frames = await collect(formatter, _stream(*tokens)) assert isinstance(frames[0], WsStreamStart) assert frames[0].request_id == req_id text_frames = [f for f in frames if isinstance(f, WsStreamText)] assert any("Hello world" in f.chunk for f in text_frames) assert isinstance(frames[-1], WsStreamEnd) @pytest.mark.asyncio async def test_home_formatter_chart_block(): req_id = "req-2" chart_json = ( '{"type": "chart", "chartType": "bar", ' '"title": "Tasks", "data": [{"x": 1}], ' '"config": {"x": {"label": "X", "color": "#fff"}}}' ) formatter = HomeFormatter(request_id=req_id, tool_results=[]) frames = await collect(formatter, _stream(("task_agent", chart_json))) block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] assert len(block_frames) == 1 assert block_frames[0].block_type == "chart" assert block_frames[0].data["chartType"] == "bar" @pytest.mark.asyncio async def test_home_formatter_invalid_chart_skipped(): req_id = "req-3" bad_chart = '{"type": "chart", "chartType": "unknown", "data": []}' formatter = HomeFormatter(request_id=req_id, tool_results=[]) frames = await collect(formatter, _stream(("task_agent", bad_chart))) block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] assert len(block_frames) == 0 # invalid chart skipped @pytest.mark.asyncio async def test_home_formatter_entity_ref_resolved(): req_id = "req-4" tool_results = [{"entity": "task", "id": "t1", "title": "My Task"}] entity_json = '{"type": "entity_ref", "entity": "task"}' formatter = HomeFormatter(request_id=req_id, tool_results=tool_results) frames = await collect(formatter, _stream(("task_agent", entity_json))) block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] assert len(block_frames) == 1 assert block_frames[0].data["entity"] == "task" assert block_frames[0].data["items"][0]["id"] == "t1" @pytest.mark.asyncio async def test_home_formatter_entity_ref_missing_skipped(): req_id = "req-5" entity_json = '{"type": "entity_ref", "entity": "task"}' formatter = HomeFormatter(request_id=req_id, tool_results=[]) frames = await collect(formatter, _stream(("task_agent", entity_json))) block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] assert len(block_frames) == 0 # no tool results → skipped @pytest.mark.asyncio async def test_home_formatter_table_block(): req_id = "req-6" table_json = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}' formatter = HomeFormatter(request_id=req_id, tool_results=[]) frames = await collect(formatter, _stream(("task_agent", table_json))) block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] assert len(block_frames) == 1 assert block_frames[0].block_type == "table" @pytest.mark.asyncio async def test_home_formatter_timeline_block(): req_id = "req-7" timeline_json = '{"type": "timeline", "checkpoints": [{"id": "c1", "title": "M1", "date": 123}]}' formatter = HomeFormatter(request_id=req_id, tool_results=[]) frames = await collect(formatter, _stream(("task_agent", timeline_json))) block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] assert len(block_frames) == 1 assert block_frames[0].block_type == "timeline" @pytest.mark.asyncio async def test_home_formatter_frame_order(): """stream_start is first, stream_end is last.""" req_id = "req-8" formatter = HomeFormatter(request_id=req_id, tool_results=[]) frames = await collect(formatter, _stream(("task_agent", '{"type": "text", "content": "Hi"}'))) assert isinstance(frames[0], WsStreamStart) assert isinstance(frames[-1], WsStreamEnd) # ── PopupFormatter ──────────────────────────────────────────────────────────── @pytest.mark.asyncio async def test_popup_formatter_domain_emitted_first(): req_id = "pop-1" formatter = PopupFormatter(request_id=req_id) tokens = [ ("task_agent", ""), # domain signal ("task_agent", "Hello"), ("task_agent", " there"), ] frames = await collect(formatter, _stream(*tokens)) assert isinstance(frames[0], WsPopupDomain) assert frames[0].domain == "tasks" assert frames[0].request_id == req_id @pytest.mark.asyncio async def test_popup_formatter_text_only(): req_id = "pop-2" formatter = PopupFormatter(request_id=req_id) tokens = [("checkpoint_agent", ""), ("checkpoint_agent", "Summary")] frames = await collect(formatter, _stream(*tokens)) assert isinstance(frames[0], WsPopupDomain) assert frames[0].domain == "checkpoints" text_frames = [f for f in frames if isinstance(f, WsStreamText)] assert len(text_frames) == 1 assert text_frames[0].chunk == "Summary" @pytest.mark.asyncio async def test_popup_formatter_no_block_frames(): """PopupFormatter must never emit WsStreamBlock.""" req_id = "pop-3" formatter = PopupFormatter(request_id=req_id) tokens = [ ("note_agent", ""), ("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'), ] frames = await collect(formatter, _stream(*tokens)) assert not any(isinstance(f, WsStreamBlock) for f in frames) @pytest.mark.asyncio async def test_popup_formatter_end_frame(): req_id = "pop-4" formatter = PopupFormatter(request_id=req_id) frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done"))) assert isinstance(frames[-1], WsStreamEnd) @pytest.mark.asyncio async def test_popup_formatter_unknown_agent_defaults_to_tasks(): req_id = "pop-5" formatter = PopupFormatter(request_id=req_id) frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi"))) assert frames[0].domain == "tasks"