215 lines
7.1 KiB
Python
215 lines
7.1 KiB
Python
"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
|
from app.schemas import (
|
|
WsFloatingDomain,
|
|
WsStreamEnd,
|
|
WsStreamStart,
|
|
WsStreamText,
|
|
)
|
|
|
|
|
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
async def _stream(*events: tuple[str, object]):
|
|
"""Async generator that yields (event_type, data) tuples."""
|
|
for event in events:
|
|
yield event
|
|
|
|
|
|
async def collect(formatter, event_stream):
    """Drain ``formatter.format(event_stream)`` and return every frame it yields.

    Args:
        formatter: Object whose ``format`` method returns an async iterator of
            frames (in this module, a ``HomeFormatter`` or ``FloatingFormatter``).
        event_stream: Async iterable of ``(event_type, data)`` tuples, e.g. one
            produced by ``_stream``.

    Returns:
        A list of the yielded frames, in emission order.
    """
    return [frame async for frame in formatter.format(event_stream)]
|
|
|
|
|
|
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_plain_text():
    """Plain text is bracketed by a WsStreamStart and a WsStreamEnd for the request."""
    req_id = "req-1"
    events = [
        ("token", "Hello world"),
        ("mutations", []),
    ]
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    assert frames, "formatter produced no frames"
    assert isinstance(frames[0], WsStreamStart)
    assert frames[0].request_id == req_id
    # Join the chunks (as the other tests do) so the check does not depend on
    # how the formatter splits text across WsStreamText frames.
    text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
    assert "Hello world" in text
    assert isinstance(frames[-1], WsStreamEnd)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_entity_tags_passed_through():
    """HomeFormatter forwards entity tags verbatim; parsing them is left to the frontend."""
    token = "Here is your project:\n<project>[abc-123]</project>\nAll good."
    frames = await collect(
        HomeFormatter(request_id="req-2"),
        _stream(("token", token), ("mutations", [])),
    )

    streamed = "".join(
        frame.chunk for frame in frames if isinstance(frame, WsStreamText)
    )
    for expected in (
        "<project>[abc-123]</project>",
        "Here is your project:",
        "All good.",
    ):
        assert expected in streamed
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_multiple_tags_passed_through():
    """Several entity tags in one token all appear in the streamed text."""
    token = "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"
    formatter = HomeFormatter(request_id="req-3")
    frames = await collect(formatter, _stream(("token", token), ("mutations", [])))

    streamed = "".join(
        frame.chunk for frame in frames if isinstance(frame, WsStreamText)
    )
    assert "<project>[p1]</project>" in streamed
    assert "<task>[t1,t2]</task>" in streamed
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_tool_end_ignored():
    """A tool_end event contributes nothing to HomeFormatter's streamed text."""
    tool_event = ("tool_end", {"name": "task_agent", "result": "3 tasks"})
    formatter = HomeFormatter(request_id="req-4")
    frames = await collect(
        formatter,
        _stream(tool_event, ("token", "No tags here."), ("mutations", [])),
    )

    chunks = [frame.chunk for frame in frames if isinstance(frame, WsStreamText)]
    assert "".join(chunks) == "No tags here."
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_mutations_in_stream_end():
    """The mutations event's payload is carried on the closing WsStreamEnd frame."""
    mutations = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
    formatter = HomeFormatter(request_id="req-5")
    frames = await collect(
        formatter, _stream(("token", "Done"), ("mutations", mutations))
    )

    last = frames[-1]
    assert isinstance(last, WsStreamEnd)
    assert len(last.mutations) == 1
    assert last.mutations[0]["action"] == "insert"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_home_formatter_frame_order():
    """The first frame is a WsStreamStart and the last is a WsStreamEnd."""
    frames = await collect(
        HomeFormatter(request_id="req-6"),
        _stream(("token", "Hi"), ("mutations", [])),
    )
    first = frames[0]
    assert isinstance(first, WsStreamStart)
    last = frames[-1]
    assert isinstance(last, WsStreamEnd)
|
|
|
|
|
|
# ── FloatingFormatter ─────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_floating_formatter_domain_from_tool_end():
    """A task_agent tool_end makes the leading WsFloatingDomain frame report 'tasks'."""
    request_id = "pop-1"
    frames = await collect(
        FloatingFormatter(request_id=request_id),
        _stream(
            ("tool_end", {"name": "task_agent", "result": "ok"}),
            ("token", "Hello"),
            ("mutations", []),
        ),
    )

    domain_frame = frames[0]
    assert isinstance(domain_frame, WsFloatingDomain)
    assert domain_frame.domain == "tasks"
    assert domain_frame.request_id == request_id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_floating_formatter_text_only():
    """timeline_agent maps to 'timelines' and the token arrives as exactly one text frame."""
    frames = await collect(
        FloatingFormatter(request_id="pop-2"),
        _stream(
            ("tool_end", {"name": "timeline_agent", "result": "done"}),
            ("token", "Summary"),
            ("mutations", []),
        ),
    )

    head = frames[0]
    assert isinstance(head, WsFloatingDomain)
    assert head.domain == "timelines"
    chunks = [frame.chunk for frame in frames if isinstance(frame, WsStreamText)]
    assert chunks == ["Summary"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_floating_formatter_no_entity_tags():
    """FloatingFormatter never emits entity tag blocks.

    The token deliberately contains entity-tag markup so the check is
    meaningful: even when tags are present in the input, only the domain,
    start, text and end frame types may appear in the output.
    """
    req_id = "pop-3"
    formatter = FloatingFormatter(request_id=req_id)
    events = [
        ("tool_end", {"name": "note_agent", "result": "data"}),
        ("token", "some text <task>[t1]</task>"),
        ("mutations", []),
    ]
    frames = await collect(formatter, _stream(*events))

    # Guard against a vacuous pass: the type loop below accepts an empty list.
    assert frames, "formatter produced no frames"
    allowed = (WsFloatingDomain, WsStreamStart, WsStreamText, WsStreamEnd)
    for f in frames:
        assert isinstance(f, allowed), f"unexpected frame type {type(f).__name__}"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_floating_formatter_end_frame():
    """The floating stream is terminated by a WsStreamEnd frame."""
    formatter = FloatingFormatter(request_id="pop-4")
    stream = _stream(
        ("tool_end", {"name": "project_agent", "result": "ok"}),
        ("token", "Done"),
        ("mutations", []),
    )
    frames = await collect(formatter, stream)
    assert isinstance(frames[-1], WsStreamEnd)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_floating_formatter_default_domain_on_early_token():
    """With no preceding tool_end, the leading domain frame falls back to 'tasks'."""
    frames = await collect(
        FloatingFormatter(request_id="pop-5"),
        _stream(("token", "hi"), ("mutations", [])),
    )
    head = frames[0]
    assert isinstance(head, WsFloatingDomain)
    assert head.domain == "tasks"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_floating_formatter_mutations_in_stream_end():
    """Mutations are carried on FloatingFormatter's closing WsStreamEnd frame.

    Mirrors test_home_formatter_mutations_in_stream_end, including the check on
    the mutation's contents rather than only its count.
    """
    req_id = "pop-6"
    muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
    events = [
        ("token", "Updated"),
        ("mutations", muts),
    ]
    formatter = FloatingFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    end_frame = frames[-1]
    assert isinstance(end_frame, WsStreamEnd)
    assert len(end_frame.mutations) == 1
    assert end_frame.mutations[0]["action"] == "update"
|