293 lines
9.1 KiB
Python
293 lines
9.1 KiB
Python
"""Tests for v3 WebSocket frame protocol schemas."""
|
|
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
|
|
from app.schemas import (
|
|
WsFrameType,
|
|
WsHomeRequest,
|
|
WsPopupDomain,
|
|
WsPopupRequest,
|
|
WsPopupScope,
|
|
WsStreamBlock,
|
|
WsStreamEnd,
|
|
WsStreamStart,
|
|
WsStreamText,
|
|
)
|
|
|
|
|
|
# ── WsFrameType ───────────────────────────────────────────────────────
|
|
|
|
|
|
def test_v3_frame_types_exist():
    """Each v3 frame name is a WsFrameType member whose value equals its name."""
    expected_members = (
        "home_request",
        "popup_request",
        "stream_start",
        "stream_text",
        "stream_block",
        "stream_end",
        "popup_domain",
        "data_request",
        "data_response",
        "mutation",
    )
    for member in expected_members:
        # Presence is checked first so a missing member fails with a readable
        # message instead of a KeyError from the subscript lookup below.
        assert hasattr(WsFrameType, member), f"WsFrameType missing: {member}"
        assert WsFrameType[member].value == member
|
|
|
|
|
|
def test_v2_frame_types_still_exist():
    """Guard backward compatibility: the v2 frame names stay on WsFrameType."""
    legacy_members = (
        "chat_request",
        "text_chunk",
        "tool_call",
        "tool_result",
        "final",
        "ping",
        "agent_run",
        "agent_data",
        "agent_complete",
        "device_hello",
    )
    for member in legacy_members:
        # Only presence is asserted here; the member values are not checked.
        assert hasattr(WsFrameType, member), f"v2 WsFrameType missing: {member}"
|
|
|
|
|
|
# ── WsHomeRequest ─────────────────────────────────────────────────────
|
|
|
|
|
|
def test_home_request_defaults():
    """A WsHomeRequest built from just a message gets its type and an empty history."""
    request = WsHomeRequest(message="Hello")

    assert request.message == "Hello"
    assert request.type == WsFrameType.home_request
    assert request.conversation_history == []
|
|
|
|
|
|
def test_home_request_with_history():
    """A supplied conversation history is stored on the frame unchanged."""
    turns = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ]
    request = WsHomeRequest(message="Follow up", conversation_history=turns)
    assert request.conversation_history == turns
|
|
|
|
|
|
def test_home_request_serializes():
    """model_dump() of a WsHomeRequest exposes type, message and empty history."""
    dumped = WsHomeRequest(message="Test").model_dump()

    assert dumped["type"] == "home_request"
    assert dumped["message"] == "Test"
    assert dumped["conversation_history"] == []
|
|
|
|
|
|
def test_home_request_deserializes():
    """model_validate() builds a WsHomeRequest from a raw dict payload."""
    parsed = WsHomeRequest.model_validate(
        {"type": "home_request", "message": "Hi there"}
    )
    assert parsed.message == "Hi there"
|
|
|
|
|
|
def test_home_request_requires_message():
    """Validating a home_request payload without a message raises ValidationError."""
    payload_without_message = {"type": "home_request"}
    with pytest.raises(ValidationError):
        WsHomeRequest.model_validate(payload_without_message)
|
|
|
|
|
|
# ── WsPopupRequest ────────────────────────────────────────────────────
|
|
|
|
|
|
def test_popup_request_basic():
    """A WsPopupRequest carries its frame type and the scope it was given."""
    task_scope = WsPopupScope(type="task", id="task-123")
    request = WsPopupRequest(message="Summarise", scope=task_scope)

    assert request.type == WsFrameType.popup_request
    assert request.scope.type == "task"
    assert request.scope.id == "task-123"
|
|
|
|
|
|
def test_popup_request_scope_without_id():
    """A scope created without an id reports id as None on the request."""
    project_scope = WsPopupScope(type="project")
    request = WsPopupRequest(message="Show all", scope=project_scope)
    assert request.scope.id is None
|
|
|
|
|
|
def test_popup_request_serializes():
    """model_dump() of a WsPopupRequest nests the scope as a dict."""
    note_scope = WsPopupScope(type="note", id="n-1")
    dumped = WsPopupRequest(message="Test", scope=note_scope).model_dump()

    assert dumped["type"] == "popup_request"
    dumped_scope = dumped["scope"]
    assert dumped_scope["type"] == "note"
    assert dumped_scope["id"] == "n-1"
|
|
|
|
|
|
def test_popup_request_invalid_scope_type():
    """An unrecognised scope type is rejected by the scope and by the request.

    The earlier form nested ``WsPopupScope(type="unknown")`` inside the
    ``WsPopupRequest(...)`` call, so the error was raised while building the
    argument and the request constructor itself was never invoked. The two
    paths are now checked separately so each ``pytest.raises`` block contains
    exactly one call that is expected to fail.
    """
    # Scope-level check: same call the original test relied on.
    with pytest.raises(ValidationError):
        WsPopupScope(type="unknown")  # type: ignore[arg-type]

    # Request-level check: pass a raw payload so the request model is the
    # thing that actually performs (and fails) validation of the scope.
    raw_request = {
        "type": "popup_request",
        "message": "X",
        "scope": {"type": "unknown"},
    }
    with pytest.raises(ValidationError):
        WsPopupRequest.model_validate(raw_request)
|
|
|
|
|
|
def test_popup_request_requires_scope():
    """Validating a popup_request payload without a scope raises ValidationError."""
    payload_without_scope = {"type": "popup_request", "message": "X"}
    with pytest.raises(ValidationError):
        WsPopupRequest.model_validate(payload_without_scope)
|
|
|
|
|
|
# ── WsStreamStart ─────────────────────────────────────────────────────
|
|
|
|
|
|
def test_stream_start():
    """WsStreamStart keeps the request id and reports the stream_start type."""
    started = WsStreamStart(request_id="req-abc")

    assert started.request_id == "req-abc"
    assert started.type == WsFrameType.stream_start
|
|
|
|
|
|
def test_stream_start_serializes():
    """model_dump() of WsStreamStart yields exactly type and request_id."""
    started = WsStreamStart(request_id="r1")
    expected = {"type": "stream_start", "request_id": "r1"}
    assert started.model_dump() == expected
|
|
|
|
|
|
def test_stream_start_deserializes():
    """model_validate() builds a WsStreamStart from a raw dict payload."""
    payload = {"type": "stream_start", "request_id": "r1"}
    parsed = WsStreamStart.model_validate(payload)
    assert parsed.request_id == "r1"
|
|
|
|
|
|
# ── WsStreamText ──────────────────────────────────────────────────────
|
|
|
|
|
|
def test_stream_text():
    """WsStreamText reports the stream_text type and keeps the chunk verbatim."""
    text_frame = WsStreamText(request_id="r1", chunk="Hello ")

    assert text_frame.type == WsFrameType.stream_text
    # The trailing space in the chunk is part of the expected value.
    assert text_frame.chunk == "Hello "
|
|
|
|
|
|
def test_stream_text_serializes():
    """model_dump() of WsStreamText yields exactly type, request_id and chunk."""
    text_frame = WsStreamText(request_id="r1", chunk="word")
    expected = {"type": "stream_text", "request_id": "r1", "chunk": "word"}
    assert text_frame.model_dump() == expected
|
|
|
|
|
|
def test_stream_text_deserializes():
    """model_validate() builds a WsStreamText from a raw dict payload."""
    parsed = WsStreamText.model_validate(
        {"type": "stream_text", "request_id": "r2", "chunk": "test"}
    )
    assert parsed.chunk == "test"
|
|
|
|
|
|
# ── WsStreamBlock ─────────────────────────────────────────────────────
|
|
|
|
|
|
def test_stream_block_chart():
    """A chart block keeps its block_type and exposes the payload via .data."""
    chart_payload = {
        "type": "chart",
        "chartType": "bar",
        "title": "Tasks",
        "data": [{"name": "Done", "count": 5}],
        "config": {"count": {"label": "Count", "color": "#4f46e5"}},
    }
    block = WsStreamBlock(request_id="r1", block_type="chart", data=chart_payload)

    assert block.type == WsFrameType.stream_block
    assert block.block_type == "chart"
    assert block.data["chartType"] == "bar"
|
|
|
|
|
|
def test_stream_block_entity_ref():
    """WsStreamBlock accepts the "entity_ref" block type."""
    entity_payload = {"type": "task", "id": "t-1", "title": "Fix bug"}
    block = WsStreamBlock(
        request_id="r1", block_type="entity_ref", data=entity_payload
    )
    assert block.block_type == "entity_ref"
|
|
|
|
|
|
def test_stream_block_table():
    """WsStreamBlock accepts the "table" block type."""
    table_payload = {"headers": ["A", "B"], "rows": [["1", "2"]]}
    block = WsStreamBlock(request_id="r1", block_type="table", data=table_payload)
    assert block.block_type == "table"
|
|
|
|
|
|
def test_stream_block_timeline():
    """WsStreamBlock accepts the "timeline" block type."""
    timeline_payload = {
        "checkpoints": [{"id": "c1", "title": "Launch", "date": 1700000000}],
    }
    block = WsStreamBlock(
        request_id="r1", block_type="timeline", data=timeline_payload
    )
    assert block.block_type == "timeline"
|
|
|
|
|
|
def test_stream_block_invalid_type():
    """Constructing a WsStreamBlock with an unknown block_type raises ValidationError."""
    with pytest.raises(ValidationError):
        WsStreamBlock(request_id="r1", block_type="unknown", data={})  # type: ignore[arg-type]
|
|
|
|
|
|
def test_stream_block_serializes():
    """model_dump() of a WsStreamBlock exposes type and block_type."""
    empty_table = {"headers": [], "rows": []}
    dumped = WsStreamBlock(
        request_id="r1", block_type="table", data=empty_table
    ).model_dump()

    assert dumped["type"] == "stream_block"
    assert dumped["block_type"] == "table"
|
|
|
|
|
|
# ── WsStreamEnd ───────────────────────────────────────────────────────
|
|
|
|
|
|
def test_stream_end_defaults():
    """A WsStreamEnd built from just a request id has an empty mutations list."""
    ended = WsStreamEnd(request_id="r1")

    assert ended.type == WsFrameType.stream_end
    assert ended.mutations == []
|
|
|
|
|
|
def test_stream_end_with_mutations():
    """Mutations passed to WsStreamEnd are retained and indexable as dicts."""
    create_task = {"action": "create", "table": "tasks", "data": {"title": "New task"}}
    ended = WsStreamEnd(request_id="r1", mutations=[create_task])

    assert len(ended.mutations) == 1
    assert ended.mutations[0]["action"] == "create"
|
|
|
|
|
|
def test_stream_end_serializes():
    """model_dump() of a default WsStreamEnd yields type, request_id and empty mutations."""
    ended = WsStreamEnd(request_id="r2")
    expected = {"type": "stream_end", "request_id": "r2", "mutations": []}
    assert ended.model_dump() == expected
|
|
|
|
|
|
def test_stream_end_deserializes():
    """model_validate() builds a WsStreamEnd from a raw dict payload."""
    parsed = WsStreamEnd.model_validate(
        {"type": "stream_end", "request_id": "r3", "mutations": []}
    )
    assert parsed.request_id == "r3"
|
|
|
|
|
|
# ── WsPopupDomain ─────────────────────────────────────────────────────
|
|
|
|
|
|
def test_popup_domain_tasks():
    """WsPopupDomain reports the popup_domain type and keeps the given domain."""
    domain_frame = WsPopupDomain(request_id="r1", domain="tasks")

    assert domain_frame.domain == "tasks"
    assert domain_frame.type == WsFrameType.popup_domain
|
|
|
|
|
|
@pytest.mark.parametrize("domain", ["tasks", "checkpoints", "notes", "projects"])
def test_popup_domain_valid_domains(domain: str):
    """Each parametrized domain value is accepted and stored as given."""
    # The parameter is annotated as plain ``str``, hence the type-checker
    # suppression on the constructor call.
    domain_frame = WsPopupDomain(request_id="r1", domain=domain)  # type: ignore[arg-type]
    assert domain_frame.domain == domain
|
|
|
|
|
|
def test_popup_domain_invalid():
    """Constructing a WsPopupDomain with an unknown domain raises ValidationError."""
    bad_domain = "invalid"
    with pytest.raises(ValidationError):
        WsPopupDomain(request_id="r1", domain=bad_domain)  # type: ignore[arg-type]
|
|
|
|
|
|
def test_popup_domain_serializes():
    """model_dump() of WsPopupDomain yields exactly type, request_id and domain."""
    domain_frame = WsPopupDomain(request_id="r1", domain="notes")
    expected = {"type": "popup_domain", "request_id": "r1", "domain": "notes"}
    assert domain_frame.model_dump() == expected
|
|
|
|
|
|
def test_popup_domain_deserializes():
    """model_validate() builds a WsPopupDomain from a raw dict payload."""
    parsed = WsPopupDomain.model_validate(
        {"type": "popup_domain", "request_id": "r1", "domain": "projects"}
    )
    assert parsed.domain == "projects"
|