From b61ded845812c8f2b32f6fe47b25afda93482b0d Mon Sep 17 00:00:00 2001 From: roberto Date: Sun, 8 Mar 2026 21:21:03 +0100 Subject: [PATCH] step-1: add v3 ws frame protocol (schemas.py) Co-Authored-By: Claude Sonnet 4.6 --- V3_MIGRATION_PLAN.md | 56 ++++++++ app/config/settings.py | 1 + app/schemas.py | 77 +++++++++++ tests/test_schemas_v3.py | 292 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 426 insertions(+) create mode 100644 tests/test_schemas_v3.py diff --git a/V3_MIGRATION_PLAN.md b/V3_MIGRATION_PLAN.md index c8b565f..26844fa 100644 --- a/V3_MIGRATION_PLAN.md +++ b/V3_MIGRATION_PLAN.md @@ -45,6 +45,14 @@ pytest tests/test_schemas_v3.py ``` +**Status**: +- [x] Step 1 complete + +**Commit**: After tests pass, commit with: +``` +git commit -m "step-1: add v3 ws frame protocol (schemas.py)" +``` + --- ## Step 2 — Agent Streaming + Tool Result Capture (agent_registry.py, agents/) @@ -65,6 +73,14 @@ pytest tests/test_schemas_v3.py pytest tests/test_agent_streaming.py ``` +**Status**: +- [ ] Step 2 complete + +**Commit**: After tests pass, commit with: +``` +git commit -m "step-2: add agent streaming and tool result capture (agent_registry.py)" +``` + --- ## Step 3 — Router Refactor (orchestrator.py) @@ -90,6 +106,14 @@ pytest tests/test_agent_streaming.py pytest tests/test_orchestrator_v3.py ``` +**Status**: +- [ ] Step 3 complete + +**Commit**: After tests pass, commit with: +``` +git commit -m "step-3: add router refactor with streaming support (orchestrator.py)" +``` + --- ## Step 4 — Output Formatting Layer (NEW: output_formatter.py) @@ -175,6 +199,14 @@ Supported entity types (matching Electron component types): pytest tests/test_output_formatter.py ``` +**Status**: +- [ ] Step 4 complete + +**Commit**: After tests pass, commit with: +``` +git commit -m "step-4: add output formatting layer (output_formatter.py)" +``` + --- ## Step 5 — Unified WS Handler (device_ws.py, chat.py, main.py) @@ -207,6 +239,14 @@ pytest tests/test_output_formatter.py pytest tests/test_ws_unified.py ``` +**Status**: +- [ ] Step 5 complete + +**Commit**: After tests pass, commit with: +``` +git commit -m "step-5: unify ws handler (device_ws.py, chat.py)" +``` + --- ## Step 6 — Memory Models + Migration (models.py, alembic) @@ -231,6 +271,14 @@ alembic upgrade head && alembic downgrade -1 && alembic upgrade head pytest tests/test_memory_models.py ``` +**Status**: +- [ ] Step 6 complete + +**Commit**: After tests pass, commit with: +``` +git commit -m "step-6: add memory models and migration (models.py, alembic)" +``` + --- ## Step 7 — Memory Middleware (NEW: memory_middleware.py) @@ -266,6 +314,14 @@ pytest tests/test_memory_models.py pytest tests/test_memory_middleware.py ``` +**Status**: +- [ ] Step 7 complete + +**Commit**: After tests pass, commit with: +``` +git commit -m "step-7: add memory middleware (memory_middleware.py, device_ws.py)" +``` + --- ## Summary diff --git a/app/config/settings.py b/app/config/settings.py index 886d2e5..dd8b292 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -26,6 +26,7 @@ class Settings(BaseSettings): OPENAI_API_KEY: str = "" ANTHROPIC_API_KEY: str = "" GOOGLE_API_KEY: str = "" + CEREBRAS_API_KEY: str = "" LLM_MODEL: str = "gpt-4o" LLM_ROUTER_MODEL: str = "gpt-4o-mini" diff --git a/app/schemas.py b/app/schemas.py index 8ec4075..e5528fa 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -161,6 +161,7 @@ class PluginInstallRequest(BaseModel): # ── WebSocket Frame Protocol ────────────────────────────────────────── class WsFrameType(str, Enum): + # ── v2 frame types (kept for backward compat) ────────────────────── chat_request = "chat_request" text_chunk = "text_chunk" tool_call = "tool_call" @@ -171,6 +172,17 @@ class WsFrameType(str, Enum): agent_data = "agent_data" agent_complete = "agent_complete" device_hello = "device_hello" + # ── v3 frame types ───────────────────────────────────────────────── + home_request = "home_request" + popup_request = "popup_request" + stream_start = "stream_start" + stream_text = "stream_text" + stream_block = "stream_block" + stream_end = "stream_end" + popup_domain = "popup_domain" + data_request = "data_request" + data_response = "data_response" + mutation = "mutation" class WsToolCall(BaseModel): @@ -249,6 +261,71 @@ class WsAgentComplete(BaseModel): errors: list[str] = Field(default_factory=list) +# ── WebSocket v3 Frame Models ───────────────────────────────────────── + +class WsPopupScope(BaseModel): + """Scope for a popup request — narrows the agent to a specific entity.""" + + type: Literal["task", "project", "note", "checkpoint"] + id: str | None = None + + +class WsHomeRequest(BaseModel): + """Client → Server: Home chat message.""" + + type: Literal[WsFrameType.home_request] = WsFrameType.home_request + message: str + conversation_history: list[dict[str, Any]] = Field(default_factory=list) + + +class WsPopupRequest(BaseModel): + """Client → Server: Popup chat message scoped to an entity.""" + + type: Literal[WsFrameType.popup_request] = WsFrameType.popup_request + message: str + scope: WsPopupScope + + +class WsStreamStart(BaseModel): + """Server → Client: signals start of a streaming response.""" + + type: Literal[WsFrameType.stream_start] = WsFrameType.stream_start + request_id: str + + +class WsStreamText(BaseModel): + """Server → Client: streamed text token.""" + + type: Literal[WsFrameType.stream_text] = WsFrameType.stream_text + request_id: str + chunk: str + + +class WsStreamBlock(BaseModel): + """Server → Client: structured block (chart, table, entity, timeline).""" + + type: Literal[WsFrameType.stream_block] = WsFrameType.stream_block + request_id: str + block_type: Literal["chart", "entity_ref", "table", "timeline"] + data: dict[str, Any] + + +class WsStreamEnd(BaseModel): + """Server → Client: signals end of a streaming response.""" + + type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end + request_id: str + mutations: list[dict[str, Any]] = Field(default_factory=list) + + +class WsPopupDomain(BaseModel): + """Server → Client: domain determined for a popup request.""" + + type: Literal[WsFrameType.popup_domain] = WsFrameType.popup_domain + request_id: str + domain: Literal["tasks", "checkpoints", "notes", "projects"] + + # ── Agent Catalog ───────────────────────────────────────────────────── class AgentCatalogItem(BaseModel): diff --git a/tests/test_schemas_v3.py b/tests/test_schemas_v3.py new file mode 100644 index 0000000..69d62cf --- /dev/null +++ b/tests/test_schemas_v3.py @@ -0,0 +1,292 @@ +"""Tests for v3 WebSocket frame protocol schemas.""" + +import pytest +from pydantic import ValidationError + +from app.schemas import ( + WsFrameType, + WsHomeRequest, + WsPopupDomain, + WsPopupRequest, + WsPopupScope, + WsStreamBlock, + WsStreamEnd, + WsStreamStart, + WsStreamText, +) + + +# ── WsFrameType ─────────────────────────────────────────────────────── + + +def test_v3_frame_types_exist(): + v3_types = [ + "home_request", + "popup_request", + "stream_start", + "stream_text", + "stream_block", + "stream_end", + "popup_domain", + "data_request", + "data_response", + "mutation", + ] + for name in v3_types: + assert hasattr(WsFrameType, name), f"WsFrameType missing: {name}" + assert WsFrameType[name].value == name + + +def test_v2_frame_types_still_exist(): + """Backward compat: v2 types must remain.""" + v2_types = [ + "chat_request", + "text_chunk", + "tool_call", + "tool_result", + "final", + "ping", + "agent_run", + "agent_data", + "agent_complete", + "device_hello", + ] + for name in v2_types: + assert hasattr(WsFrameType, name), f"v2 WsFrameType missing: {name}" + + +# ── WsHomeRequest ───────────────────────────────────────────────────── + + +def test_home_request_defaults(): + frame = WsHomeRequest(message="Hello") + assert frame.type == WsFrameType.home_request + assert frame.message == "Hello" + assert frame.conversation_history == [] + + +def test_home_request_with_history(): + history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}] + frame = WsHomeRequest(message="Follow up", conversation_history=history) + assert frame.conversation_history == history + + +def test_home_request_serializes(): + frame = WsHomeRequest(message="Test") + data = frame.model_dump() + assert data["type"] == "home_request" + assert data["message"] == "Test" + assert data["conversation_history"] == [] + + +def test_home_request_deserializes(): + raw = {"type": "home_request", "message": "Hi there"} + frame = WsHomeRequest.model_validate(raw) + assert frame.message == "Hi there" + + +def test_home_request_requires_message(): + with pytest.raises(ValidationError): + WsHomeRequest.model_validate({"type": "home_request"}) + + +# ── WsPopupRequest ──────────────────────────────────────────────────── + + +def test_popup_request_basic(): + frame = WsPopupRequest( + message="Summarise", + scope=WsPopupScope(type="task", id="task-123"), + ) + assert frame.type == WsFrameType.popup_request + assert frame.scope.type == "task" + assert frame.scope.id == "task-123" + + +def test_popup_request_scope_without_id(): + frame = WsPopupRequest( + message="Show all", + scope=WsPopupScope(type="project"), + ) + assert frame.scope.id is None + + +def test_popup_request_serializes(): + frame = WsPopupRequest( + message="Test", + scope=WsPopupScope(type="note", id="n-1"), + ) + data = frame.model_dump() + assert data["type"] == "popup_request" + assert data["scope"]["type"] == "note" + assert data["scope"]["id"] == "n-1" + + +def test_popup_request_invalid_scope_type(): + with pytest.raises(ValidationError): + WsPopupRequest( + message="X", + scope=WsPopupScope(type="unknown"), # type: ignore[arg-type] + ) + + +def test_popup_request_requires_scope(): + with pytest.raises(ValidationError): + WsPopupRequest.model_validate({"type": "popup_request", "message": "X"}) + + +# ── WsStreamStart ───────────────────────────────────────────────────── + + +def test_stream_start(): + frame = WsStreamStart(request_id="req-abc") + assert frame.type == WsFrameType.stream_start + assert frame.request_id == "req-abc" + + +def test_stream_start_serializes(): + data = WsStreamStart(request_id="r1").model_dump() + assert data == {"type": "stream_start", "request_id": "r1"} + + +def test_stream_start_deserializes(): + frame = WsStreamStart.model_validate({"type": "stream_start", "request_id": "r1"}) + assert frame.request_id == "r1" + + +# ── WsStreamText ────────────────────────────────────────────────────── + + +def test_stream_text(): + frame = WsStreamText(request_id="r1", chunk="Hello ") + assert frame.type == WsFrameType.stream_text + assert frame.chunk == "Hello " + + +def test_stream_text_serializes(): + data = WsStreamText(request_id="r1", chunk="word").model_dump() + assert data == {"type": "stream_text", "request_id": "r1", "chunk": "word"} + + +def test_stream_text_deserializes(): + raw = {"type": "stream_text", "request_id": "r2", "chunk": "test"} + frame = WsStreamText.model_validate(raw) + assert frame.chunk == "test" + + +# ── WsStreamBlock ───────────────────────────────────────────────────── + + +def test_stream_block_chart(): + data = { + "type": "chart", + "chartType": "bar", + "title": "Tasks", + "data": [{"name": "Done", "count": 5}], + "config": {"count": {"label": "Count", "color": "#4f46e5"}}, + } + frame = WsStreamBlock(request_id="r1", block_type="chart", data=data) + assert frame.type == WsFrameType.stream_block + assert frame.block_type == "chart" + assert frame.data["chartType"] == "bar" + + +def test_stream_block_entity_ref(): + frame = WsStreamBlock( + request_id="r1", + block_type="entity_ref", + data={"type": "task", "id": "t-1", "title": "Fix bug"}, + ) + assert frame.block_type == "entity_ref" + + +def test_stream_block_table(): + frame = WsStreamBlock( + request_id="r1", + block_type="table", + data={"headers": ["A", "B"], "rows": [["1", "2"]]}, + ) + assert frame.block_type == "table" + + +def test_stream_block_timeline(): + frame = WsStreamBlock( + request_id="r1", + block_type="timeline", + data={"checkpoints": [{"id": "c1", "title": "Launch", "date": 1700000000}]}, + ) + assert frame.block_type == "timeline" + + +def test_stream_block_invalid_type(): + with pytest.raises(ValidationError): + WsStreamBlock( + request_id="r1", + block_type="unknown", # type: ignore[arg-type] + data={}, + ) + + +def test_stream_block_serializes(): + frame = WsStreamBlock(request_id="r1", block_type="table", data={"headers": [], "rows": []}) + d = frame.model_dump() + assert d["type"] == "stream_block" + assert d["block_type"] == "table" + + +# ── WsStreamEnd ─────────────────────────────────────────────────────── + + +def test_stream_end_defaults(): + frame = WsStreamEnd(request_id="r1") + assert frame.type == WsFrameType.stream_end + assert frame.mutations == [] + + +def test_stream_end_with_mutations(): + mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}] + frame = WsStreamEnd(request_id="r1", mutations=mutations) + assert len(frame.mutations) == 1 + assert frame.mutations[0]["action"] == "create" + + +def test_stream_end_serializes(): + data = WsStreamEnd(request_id="r2").model_dump() + assert data == {"type": "stream_end", "request_id": "r2", "mutations": []} + + +def test_stream_end_deserializes(): + raw = {"type": "stream_end", "request_id": "r3", "mutations": []} + frame = WsStreamEnd.model_validate(raw) + assert frame.request_id == "r3" + + +# ── WsPopupDomain ───────────────────────────────────────────────────── + + +def test_popup_domain_tasks(): + frame = WsPopupDomain(request_id="r1", domain="tasks") + assert frame.type == WsFrameType.popup_domain + assert frame.domain == "tasks" + + +@pytest.mark.parametrize("domain", ["tasks", "checkpoints", "notes", "projects"]) +def test_popup_domain_valid_domains(domain: str): + frame = WsPopupDomain(request_id="r1", domain=domain) # type: ignore[arg-type] + assert frame.domain == domain + + +def test_popup_domain_invalid(): + with pytest.raises(ValidationError): + WsPopupDomain(request_id="r1", domain="invalid") # type: ignore[arg-type] + + +def test_popup_domain_serializes(): + d = WsPopupDomain(request_id="r1", domain="notes").model_dump() + assert d == {"type": "popup_domain", "request_id": "r1", "domain": "notes"} + + +def test_popup_domain_deserializes(): + raw = {"type": "popup_domain", "request_id": "r1", "domain": "projects"} + frame = WsPopupDomain.model_validate(raw) + assert frame.domain == "projects"