"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""

from __future__ import annotations

import pytest

from app.core.output_formatter import HomeFormatter, FloatingFormatter
from app.schemas import (
    WsFloatingDomain,
    WsStreamBlock,
    WsStreamEnd,
    WsStreamStart,
    WsStreamText,
)


# ── helpers ───────────────────────────────────────────────────────────────────


async def _stream(*events: tuple[str, object]):
    """Async generator that yields (event_type, data) tuples.

    Stands in for the upstream event stream consumed by ``formatter.format``.
    """
    for event in events:
        yield event


async def collect(formatter, event_stream) -> list[object]:
    """Drain ``formatter.format(event_stream)`` and return every frame, in order."""
    return [frame async for frame in formatter.format(event_stream)]


# ── HomeFormatter ─────────────────────────────────────────────────────────────
#
# NOTE(review): the token literals in the entity-tag tests below look like they
# lost their entity markup (possibly an angle-bracket wrapper naming the entity)
# when this file was flattened. The assertions expect data["entity"] to be
# "projects" / "tasks", which the remaining "[id]" text alone does not encode.
# Confirm the tag syntax HomeFormatter parses and restore it in these strings.


@pytest.mark.asyncio
async def test_home_formatter_plain_text():
    req_id = "req-1"
    events = [
        ("token", "Hello world"),
        ("mutations", []),
    ]
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    assert isinstance(frames[0], WsStreamStart)
    assert frames[0].request_id == req_id

    text_frames = [f for f in frames if isinstance(f, WsStreamText)]
    assert any("Hello world" in f.chunk for f in text_frames)

    assert isinstance(frames[-1], WsStreamEnd)


@pytest.mark.asyncio
async def test_home_formatter_entity_tag_single_id():
    """An entity tag carrying one [id] emits a WsStreamBlock with entity + ids."""
    req_id = "req-2"
    events = [
        ("token", "Here is your project:\n[abc-123]\nAll good."),
        ("mutations", []),
    ]
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 1
    assert block_frames[0].block_type == "entity_ref"
    assert block_frames[0].data["entity"] == "projects"
    assert block_frames[0].data["ids"] == ["abc-123"]

    # Surrounding text is streamed
    text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
    assert "Here is your project:" in text
    assert "All good." in text
    # The raw tag itself should NOT appear in streamed text. (A check against
    # the empty string can never pass — "" is a substring of every string — so
    # assert on the bracketed id that the tag consumes instead.)
    assert "[abc-123]" not in text


@pytest.mark.asyncio
async def test_home_formatter_entity_tag_multiple_ids():
    req_id = "req-3"
    events = [
        ("token", "Pending:\n[id-1,id-2,id-3]"),
        ("mutations", []),
    ]
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 1
    assert block_frames[0].data["entity"] == "tasks"
    assert block_frames[0].data["ids"] == ["id-1", "id-2", "id-3"]


@pytest.mark.asyncio
async def test_home_formatter_multiple_entity_tags():
    req_id = "req-4"
    events = [
        ("token", "[p1]\nText\n[t1,t2]"),
        ("mutations", []),
    ]
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 2
    entities = {b.data["entity"] for b in block_frames}
    assert entities == {"projects", "tasks"}


@pytest.mark.asyncio
async def test_home_formatter_tag_split_across_tokens():
    """Entity tag arrives across two token chunks — still detected."""
    req_id = "req-5"
    events = [
        ("token", "See: [abc-"),
        ("token", "123] done"),
        ("mutations", []),
    ]
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 1
    assert block_frames[0].data["entity"] == "projects"
    assert block_frames[0].data["ids"] == ["abc-123"]


@pytest.mark.asyncio
async def test_home_formatter_tool_end_ignored():
    """tool_end events no longer produce blocks — only entity tags do."""
    req_id = "req-6"
    events = [
        ("tool_end", {"name": "task_agent", "result": "3 tasks"}),
        ("token", "No tags here."),
        ("mutations", []),
    ]
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 0


@pytest.mark.asyncio
async def test_home_formatter_mutations_in_stream_end():
    req_id = "req-7"
    muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
    events = [
        ("token", "Done"),
        ("mutations", muts),
    ]
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    end_frame = frames[-1]
    assert isinstance(end_frame, WsStreamEnd)
    assert len(end_frame.mutations) == 1
    assert end_frame.mutations[0]["action"] == "insert"


@pytest.mark.asyncio
async def test_home_formatter_frame_order():
    """stream_start is first, stream_end is last."""
    req_id = "req-8"
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))

    assert isinstance(frames[0], WsStreamStart)
    assert isinstance(frames[-1], WsStreamEnd)


# ── FloatingFormatter ─────────────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_floating_formatter_domain_from_tool_end():
    req_id = "pop-1"
    formatter = FloatingFormatter(request_id=req_id)
    events = [
        ("tool_end", {"name": "task_agent", "result": "ok"}),
        ("token", "Hello"),
        ("mutations", []),
    ]
    frames = await collect(formatter, _stream(*events))

    assert isinstance(frames[0], WsFloatingDomain)
    assert frames[0].domain == "tasks"
    assert frames[0].request_id == req_id


@pytest.mark.asyncio
async def test_floating_formatter_text_only():
    req_id = "pop-2"
    formatter = FloatingFormatter(request_id=req_id)
    events = [
        ("tool_end", {"name": "timeline_agent", "result": "done"}),
        ("token", "Summary"),
        ("mutations", []),
    ]
    frames = await collect(formatter, _stream(*events))

    assert isinstance(frames[0], WsFloatingDomain)
    assert frames[0].domain == "timelines"

    text_frames = [f for f in frames if isinstance(f, WsStreamText)]
    assert len(text_frames) == 1
    assert text_frames[0].chunk == "Summary"


@pytest.mark.asyncio
async def test_floating_formatter_no_block_frames():
    """FloatingFormatter must never emit WsStreamBlock."""
    req_id = "pop-3"
    formatter = FloatingFormatter(request_id=req_id)
    events = [
        ("tool_end", {"name": "note_agent", "result": "data"}),
        ("token", "some text"),
        ("mutations", []),
    ]
    frames = await collect(formatter, _stream(*events))

    assert not any(isinstance(f, WsStreamBlock) for f in frames)


@pytest.mark.asyncio
async def test_floating_formatter_end_frame():
    req_id = "pop-4"
    formatter = FloatingFormatter(request_id=req_id)
    events = [
        ("tool_end", {"name": "project_agent", "result": "ok"}),
        ("token", "Done"),
        ("mutations", []),
    ]
    frames = await collect(formatter, _stream(*events))

    assert isinstance(frames[-1], WsStreamEnd)


@pytest.mark.asyncio
async def test_floating_formatter_default_domain_on_early_token():
    """When the first event is a token (no tool_end yet), default to 'tasks'."""
    req_id = "pop-5"
    formatter = FloatingFormatter(request_id=req_id)
    events = [("token", "hi"), ("mutations", [])]
    frames = await collect(formatter, _stream(*events))

    assert isinstance(frames[0], WsFloatingDomain)
    assert frames[0].domain == "tasks"


@pytest.mark.asyncio
async def test_floating_formatter_mutations_in_stream_end():
    req_id = "pop-6"
    muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
    events = [
        ("token", "Updated"),
        ("mutations", muts),
    ]
    formatter = FloatingFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    end_frame = frames[-1]
    assert isinstance(end_frame, WsStreamEnd)
    assert len(end_frame.mutations) == 1
    # Mirror the HomeFormatter check: the mutation is passed through intact.
    assert end_frame.mutations[0]["action"] == "update"