231 lines
7.3 KiB
Python
231 lines
7.3 KiB
Python
"""Tests for v3 WebSocket frame protocol schemas."""
|
|
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
|
|
from app.schemas import (
|
|
WsFrameType,
|
|
WsHomeRequest,
|
|
WsFloatingDomain,
|
|
WsFloatingRequest,
|
|
WsFloatingScope,
|
|
WsStreamEnd,
|
|
WsStreamStart,
|
|
WsStreamText,
|
|
)
|
|
|
|
|
|
# ── WsFrameType ───────────────────────────────────────────────────────
|
|
|
|
|
|
def test_v3_frame_types_exist():
    """Every v3 frame type is an enum member whose value equals its name."""
    expected_members = (
        "home_request",
        "floating_request",
        "stream_start",
        "stream_text",
        "stream_end",
        "floating_domain",
        "data_request",
        "data_response",
        "mutation",
    )
    for member_name in expected_members:
        assert hasattr(WsFrameType, member_name), f"WsFrameType missing: {member_name}"
        member = WsFrameType[member_name]
        assert member.value == member_name
|
|
|
|
|
|
def test_v2_frame_types_still_exist():
    """Backward compatibility: every v2 frame type must still be defined."""
    legacy_members = (
        "chat_request",
        "text_chunk",
        "tool_call",
        "tool_result",
        "final",
        "ping",
        "agent_run",
        "agent_data",
        "agent_complete",
        "device_hello",
    )
    for legacy in legacy_members:
        assert hasattr(WsFrameType, legacy), f"v2 WsFrameType missing: {legacy}"
|
|
|
|
|
|
# ── WsHomeRequest ─────────────────────────────────────────────────────
|
|
|
|
|
|
def test_home_request_defaults():
    """A home request built from only a message gets its frame type and an empty history."""
    request = WsHomeRequest(message="Hello")
    assert request.type == WsFrameType.home_request
    assert request.message == "Hello"
    assert request.conversation_history == []
|
|
|
|
|
|
def test_home_request_with_history():
    """A supplied conversation history compares equal to what was passed in."""
    user_turn = {"role": "user", "content": "Hi"}
    assistant_turn = {"role": "assistant", "content": "Hello!"}
    prior_turns = [user_turn, assistant_turn]
    request = WsHomeRequest(message="Follow up", conversation_history=prior_turns)
    assert request.conversation_history == prior_turns
|
|
|
|
|
|
def test_home_request_serializes():
    """model_dump exposes the wire type string, the message and the default history."""
    payload = WsHomeRequest(message="Test").model_dump()
    assert payload["type"] == "home_request"
    assert payload["message"] == "Test"
    assert payload["conversation_history"] == []
|
|
|
|
|
|
def test_home_request_deserializes():
    """A raw dict without a history validates into a home request."""
    parsed = WsHomeRequest.model_validate({"type": "home_request", "message": "Hi there"})
    assert parsed.message == "Hi there"
|
|
|
|
|
|
def test_home_request_requires_message():
    """A home request without ``message`` is rejected, and the error names that field."""
    with pytest.raises(ValidationError) as exc_info:
        WsHomeRequest.model_validate({"type": "home_request"})
    # Pin the failure to the missing field so an unrelated validation error
    # (e.g. on ``type``) cannot make this test pass by accident. Model-level
    # errors have an empty ``loc``, so those are skipped rather than indexed.
    failing_fields = {err["loc"][0] for err in exc_info.value.errors() if err["loc"]}
    assert "message" in failing_fields
|
|
|
|
|
|
# ── WsFloatingRequest ────────────────────────────────────────────────────
|
|
|
|
|
|
def test_floating_request_basic():
    """A request scoped to a single task keeps the scope type and id."""
    task_scope = WsFloatingScope(type="task", id="task-123")
    request = WsFloatingRequest(message="Summarise", scope=task_scope)
    assert request.type == WsFrameType.floating_request
    assert request.scope.type == "task"
    assert request.scope.id == "task-123"
|
|
|
|
|
|
def test_floating_request_scope_without_id():
    """A scope may omit ``id``, in which case it is None."""
    project_scope = WsFloatingScope(type="project")
    request = WsFloatingRequest(message="Show all", scope=project_scope)
    assert request.scope.id is None
|
|
|
|
|
|
def test_floating_request_serializes():
    """The dumped frame has the wire type and a nested scope mapping."""
    note_scope = WsFloatingScope(type="note", id="n-1")
    payload = WsFloatingRequest(message="Test", scope=note_scope).model_dump()
    assert payload["type"] == "floating_request"
    scope_payload = payload["scope"]
    assert scope_payload["type"] == "note"
    assert scope_payload["id"] == "n-1"
|
|
|
|
|
|
def test_floating_request_invalid_scope_type():
    """Unknown scope types are rejected, both directly and inside a raw request frame."""
    # The scope model on its own must refuse an unknown type.
    with pytest.raises(ValidationError):
        WsFloatingScope(type="unknown")  # type: ignore[arg-type]

    # Validating a whole raw frame exercises WsFloatingRequest's own handling of
    # the nested scope; a constructor call like WsFloatingScope(...) above would
    # raise before the request model is ever involved.
    with pytest.raises(ValidationError) as exc_info:
        WsFloatingRequest.model_validate(
            {"type": "floating_request", "message": "X", "scope": {"type": "unknown"}}
        )
    # The failure must come from the scope, not from some other field.
    failing_fields = {err["loc"][0] for err in exc_info.value.errors() if err["loc"]}
    assert "scope" in failing_fields
|
|
|
|
|
|
def test_floating_request_requires_scope():
    """A floating request without ``scope`` is rejected, and the error names that field."""
    with pytest.raises(ValidationError) as exc_info:
        WsFloatingRequest.model_validate({"type": "floating_request", "message": "X"})
    # Pin the failure to the missing field so an unrelated validation error
    # cannot make this test pass by accident.
    failing_fields = {err["loc"][0] for err in exc_info.value.errors() if err["loc"]}
    assert "scope" in failing_fields
|
|
|
|
|
|
# ── WsStreamStart ─────────────────────────────────────────────────────
|
|
|
|
|
|
def test_stream_start():
    """A stream_start frame carries its frame type and the given request id."""
    start = WsStreamStart(request_id="req-abc")
    assert start.type == WsFrameType.stream_start
    assert start.request_id == "req-abc"
|
|
|
|
|
|
def test_stream_start_serializes():
    """The dumped stream_start frame contains exactly type and request_id."""
    expected = {"type": "stream_start", "request_id": "r1"}
    assert WsStreamStart(request_id="r1").model_dump() == expected
|
|
|
|
|
|
def test_stream_start_deserializes():
    """A raw stream_start dict validates and keeps its request id."""
    raw_frame = {"type": "stream_start", "request_id": "r1"}
    parsed = WsStreamStart.model_validate(raw_frame)
    assert parsed.request_id == "r1"
|
|
|
|
|
|
# ── WsStreamText ──────────────────────────────────────────────────────
|
|
|
|
|
|
def test_stream_text():
    """A stream_text frame carries its frame type and the chunk verbatim."""
    text_frame = WsStreamText(request_id="r1", chunk="Hello ")
    assert text_frame.type == WsFrameType.stream_text
    # The trailing space in the chunk must be preserved exactly.
    assert text_frame.chunk == "Hello "
|
|
|
|
|
|
def test_stream_text_serializes():
    """The dumped stream_text frame contains exactly type, request_id and chunk."""
    expected = {"type": "stream_text", "request_id": "r1", "chunk": "word"}
    dumped = WsStreamText(request_id="r1", chunk="word").model_dump()
    assert dumped == expected
|
|
|
|
|
|
def test_stream_text_deserializes():
    """A raw stream_text dict validates and keeps its chunk."""
    parsed = WsStreamText.model_validate(
        {"type": "stream_text", "request_id": "r2", "chunk": "test"}
    )
    assert parsed.chunk == "test"
|
|
|
|
|
|
# ── WsStreamEnd ───────────────────────────────────────────────────────
|
|
|
|
|
|
def test_stream_end_defaults():
    """A stream_end frame without mutations gets its frame type and an empty mutation list."""
    end_frame = WsStreamEnd(request_id="r1")
    assert end_frame.type == WsFrameType.stream_end
    assert end_frame.mutations == []
|
|
|
|
|
|
def test_stream_end_with_mutations():
    """Mutations passed to a stream_end frame are kept and remain subscriptable by key."""
    create_task = {"action": "create", "table": "tasks", "data": {"title": "New task"}}
    end_frame = WsStreamEnd(request_id="r1", mutations=[create_task])
    assert len(end_frame.mutations) == 1
    first_mutation = end_frame.mutations[0]
    assert first_mutation["action"] == "create"
|
|
|
|
|
|
def test_stream_end_serializes():
    """The dumped stream_end frame contains type, request_id and an empty mutation list."""
    expected = {"type": "stream_end", "request_id": "r2", "mutations": []}
    assert WsStreamEnd(request_id="r2").model_dump() == expected
|
|
|
|
|
|
def test_stream_end_deserializes():
    """A raw stream_end dict validates and keeps its request id."""
    parsed = WsStreamEnd.model_validate(
        {"type": "stream_end", "request_id": "r3", "mutations": []}
    )
    assert parsed.request_id == "r3"
|
|
|
|
|
|
# ── WsFloatingDomain ─────────────────────────────────────────────────────
|
|
|
|
|
|
def test_floating_domain_tasks():
    """A floating_domain frame for "tasks" carries its frame type and domain."""
    domain_frame = WsFloatingDomain(request_id="r1", domain="tasks")
    assert domain_frame.type == WsFrameType.floating_domain
    assert domain_frame.domain == "tasks"
|
|
|
|
|
|
@pytest.mark.parametrize("domain", ("tasks", "timelines", "notes", "projects"))
def test_floating_domain_valid_domains(domain: str):
    """Each supported domain name is accepted and stored as given."""
    accepted = WsFloatingDomain(request_id="r1", domain=domain)  # type: ignore[arg-type]
    assert accepted.domain == domain
|
|
|
|
|
|
def test_floating_domain_invalid():
    """An unsupported domain is rejected, and the error names the ``domain`` field."""
    with pytest.raises(ValidationError) as exc_info:
        WsFloatingDomain(request_id="r1", domain="invalid")  # type: ignore[arg-type]
    # Pin the failure to ``domain`` so an unrelated validation error cannot
    # make this test pass by accident.
    failing_fields = {err["loc"][0] for err in exc_info.value.errors() if err["loc"]}
    assert "domain" in failing_fields
|
|
|
|
|
|
def test_floating_domain_serializes():
    """The dumped floating_domain frame contains exactly type, request_id and domain."""
    expected = {"type": "floating_domain", "request_id": "r1", "domain": "notes"}
    dumped = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
    assert dumped == expected
|
|
|
|
|
|
def test_floating_domain_deserializes():
    """A raw floating_domain dict validates and keeps its domain."""
    parsed = WsFloatingDomain.model_validate(
        {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
    )
    assert parsed.domain == "projects"
|