"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter.""" from __future__ import annotations import pytest from app.core.output_formatter import HomeFormatter, FloatingFormatter from app.schemas import ( WsFloatingDomain, WsStreamBlock, WsStreamEnd, WsStreamStart, WsStreamText, ) # ── helpers ─────────────────────────────────────────────────────────────────── async def _stream(*events: tuple[str, object]): """Async generator that yields (event_type, data) tuples.""" for event in events: yield event async def collect(formatter, event_stream): frames = [] async for frame in formatter.format(event_stream): frames.append(frame) return frames # ── HomeFormatter ───────────────────────────────────────────────────────────── @pytest.mark.asyncio async def test_home_formatter_text_token(): req_id = "req-1" events = [ ("token", "Hello world"), ("mutations", []), ] formatter = HomeFormatter(request_id=req_id) frames = await collect(formatter, _stream(*events)) assert isinstance(frames[0], WsStreamStart) assert frames[0].request_id == req_id text_frames = [f for f in frames if isinstance(f, WsStreamText)] assert any("Hello world" in f.chunk for f in text_frames) assert isinstance(frames[-1], WsStreamEnd) @pytest.mark.asyncio async def test_home_formatter_entity_ref_from_tool_end(): req_id = "req-2" events = [ ("tool_end", {"name": "task_agent", "result": "Found 3 tasks."}), ("token", "Here are your tasks."), ("mutations", []), ] formatter = HomeFormatter(request_id=req_id) frames = await collect(formatter, _stream(*events)) block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] assert len(block_frames) == 1 assert block_frames[0].block_type == "entity_ref" assert block_frames[0].data["entity"] == "tasks" assert block_frames[0].data["result"] == "Found 3 tasks." 
@pytest.mark.asyncio
async def test_home_formatter_unknown_agent_no_block():
    """A tool_end from an unrecognised agent name emits no block frame."""
    req_id = "req-3"
    events = [
        ("tool_end", {"name": "unknown_agent", "result": "stuff"}),
        ("mutations", []),
    ]
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert block_frames == []  # unknown agent → no entity mapping


@pytest.mark.asyncio
async def test_home_formatter_mutations_in_stream_end():
    """Mutations from the stream are carried on the closing WsStreamEnd frame."""
    req_id = "req-4"
    muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
    events = [
        ("token", "Done"),
        ("mutations", muts),
    ]
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    end_frame = frames[-1]
    assert isinstance(end_frame, WsStreamEnd)
    assert len(end_frame.mutations) == 1
    assert end_frame.mutations[0]["action"] == "insert"


@pytest.mark.asyncio
async def test_home_formatter_frame_order():
    """stream_start is first, stream_end is last."""
    req_id = "req-5"
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))

    assert isinstance(frames[0], WsStreamStart)
    assert isinstance(frames[-1], WsStreamEnd)


@pytest.mark.asyncio
async def test_home_formatter_multiple_tool_ends():
    """Each recognised tool_end yields its own block with the matching entity."""
    req_id = "req-6"
    events = [
        ("tool_end", {"name": "task_agent", "result": "3 tasks"}),
        ("tool_end", {"name": "project_agent", "result": "2 projects"}),
        ("token", "Overview done."),
        ("mutations", []),
    ]
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 2
    entities = {b.data["entity"] for b in block_frames}
    assert entities == {"tasks", "projects"}


# ── FloatingFormatter ─────────────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_floating_formatter_domain_from_tool_end():
    """The first frame is a WsFloatingDomain derived from the tool_end agent."""
    req_id = "pop-1"
    formatter = FloatingFormatter(request_id=req_id)
    events = [
        ("tool_end", {"name": "task_agent", "result": "ok"}),
        ("token", "Hello"),
        ("mutations", []),
    ]
    frames = await collect(formatter, _stream(*events))

    assert isinstance(frames[0], WsFloatingDomain)
    assert frames[0].domain == "tasks"
    assert frames[0].request_id == req_id


@pytest.mark.asyncio
async def test_floating_formatter_text_only():
    """A single token yields one text frame after the 'timelines' domain frame."""
    req_id = "pop-2"
    formatter = FloatingFormatter(request_id=req_id)
    events = [
        ("tool_end", {"name": "timeline_agent", "result": "done"}),
        ("token", "Summary"),
        ("mutations", []),
    ]
    frames = await collect(formatter, _stream(*events))

    assert isinstance(frames[0], WsFloatingDomain)
    assert frames[0].domain == "timelines"

    text_frames = [f for f in frames if isinstance(f, WsStreamText)]
    assert len(text_frames) == 1
    assert text_frames[0].chunk == "Summary"


@pytest.mark.asyncio
async def test_floating_formatter_no_block_frames():
    """FloatingFormatter must never emit WsStreamBlock."""
    req_id = "pop-3"
    formatter = FloatingFormatter(request_id=req_id)
    events = [
        ("tool_end", {"name": "note_agent", "result": "data"}),
        ("token", "some text"),
        ("mutations", []),
    ]
    frames = await collect(formatter, _stream(*events))

    assert not any(isinstance(f, WsStreamBlock) for f in frames)


@pytest.mark.asyncio
async def test_floating_formatter_end_frame():
    """The last frame emitted is a WsStreamEnd."""
    req_id = "pop-4"
    formatter = FloatingFormatter(request_id=req_id)
    events = [
        ("tool_end", {"name": "project_agent", "result": "ok"}),
        ("token", "Done"),
        ("mutations", []),
    ]
    frames = await collect(formatter, _stream(*events))

    assert isinstance(frames[-1], WsStreamEnd)


@pytest.mark.asyncio
async def test_floating_formatter_default_domain_on_early_token():
    """When the first event is a token (no tool_end yet), default to 'tasks'."""
    req_id = "pop-5"
    formatter = FloatingFormatter(request_id=req_id)
    events = [("token", "hi"), ("mutations", [])]
    frames = await collect(formatter, _stream(*events))

    assert isinstance(frames[0], WsFloatingDomain)
    assert frames[0].domain == "tasks"


@pytest.mark.asyncio
async def test_floating_formatter_mutations_in_stream_end():
    """Mutations from the stream are carried on the closing WsStreamEnd frame."""
    req_id = "pop-6"
    muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
    events = [
        ("token", "Updated"),
        ("mutations", muts),
    ]
    formatter = FloatingFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    end_frame = frames[-1]
    assert isinstance(end_frame, WsStreamEnd)
    assert len(end_frame.mutations) == 1
    # Check the payload as well as the count, mirroring the HomeFormatter test,
    # so a formatter that emits a wrong or placeholder mutation is caught.
    assert end_frame.mutations[0]["action"] == "update"