- device_ws.py: dispatch home_request/popup_request to HomeFormatter/PopupFormatter via async tasks; each request gets a UUID request_id for frame correlation
- chat.py: remove the chat_stream WS endpoint (superseded by the unified device WS); keep the POST /chat REST fallback unchanged
- 5 new integration tests pass; all 22 existing device_ws tests still pass

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
158 lines
5.8 KiB
Python
158 lines
5.8 KiB
Python
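"""Integration tests for the unified device WebSocket handler (Step 5).

Exercises the device WS endpoint with ``home_request`` and ``popup_request``
frames and verifies that the expected v3 frame sequence is returned. Also
checks that a ``tool_result`` with an unknown id leaves the connection alive
and that a connection with an invalid JWT is rejected.

LLM calls are mocked to avoid any network dependency.
"""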
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from app.db import get_session
|
|
from app.main import app
|
|
from app.schemas import WsFrameType
|
|
from tests.conftest import TEST_USER_IDS, make_jwt
|
|
|
|
# User id shared by every test in this module. It is looked up under the
# "power" key, the same string each test passes as the first argument to
# make_jwt(), so token and user id belong to the same test user.
USER_ID = TEST_USER_IDS["power"]
|
|
|
|
|
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's ``get_session`` dependency at the test ``db_session``.

    Autouse, so every test in this module talks to the per-test session.
    The override is installed before the test body runs and removed again
    on teardown.
    """

    async def _session_override():
        # Async-generator dependency: hand out the shared test session.
        yield db_session

    app.dependency_overrides[get_session] = _session_override
    yield
    # Tolerate the key already being gone (pop with a default).
    app.dependency_overrides.pop(get_session, None)
|
|
|
|
|
|
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
|
"""Receive frames until stream_end (or stream_end inside popup flow), or max_frames."""
|
|
frames = []
|
|
for _ in range(max_frames):
|
|
raw = ws.receive_text()
|
|
frame = json.loads(raw)
|
|
frames.append(frame)
|
|
if frame.get("type") == WsFrameType.stream_end:
|
|
break
|
|
return frames
|
|
|
|
|
|
async def _mock_home_stream(user_id, message, context, reg=None):
|
|
yield "task_agent", ""
|
|
yield "task_agent", '{"type": "text", "content": "Hello"}'
|
|
|
|
|
|
async def _mock_popup_stream(user_id, message, context, reg=None):
|
|
yield "task_agent", ""
|
|
yield "task_agent", "Here is a summary"
|
|
|
|
|
|
# ── tests ─────────────────────────────────────────────────────────────────────
|
|
|
|
def test_home_request_produces_stream_frames(client):
    """A home_request produces a stream_start frame and, later, a stream_end.

    Only the presence and relative order of the two markers is asserted; any
    frames in between are collected but not inspected.
    """
    jwt = make_jwt("power", user_id=USER_ID)
    hello = {"type": "device_hello", "device_id": "dev-1", "agent_ids": []}
    request = {
        "type": "home_request",
        "request_id": "r1",
        "message": "List my tasks",
        "conversation_history": [],
    }

    with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_home_stream):
        with client.websocket_connect(f"/api/v1/ws/device?token={jwt}") as ws:
            # Handshake first, then the actual request.
            for payload in (hello, request):
                ws.send_text(json.dumps(payload))
            frames = _recv_until_end(ws)

    kinds = [frame["type"] for frame in frames]
    assert WsFrameType.stream_start in kinds
    assert WsFrameType.stream_end in kinds
    assert kinds.index(WsFrameType.stream_start) < kinds.index(WsFrameType.stream_end)
|
|
|
|
|
|
def test_popup_request_produces_domain_frame(client):
    """A popup_request emits a popup_domain frame before stream_end.

    For a task-scoped request the popup_domain frame must name the "tasks"
    domain and echo the request_id "p1".
    """
    jwt = make_jwt("power", user_id=USER_ID)
    hello = {"type": "device_hello", "device_id": "dev-2", "agent_ids": []}
    popup = {
        "type": "popup_request",
        "request_id": "p1",
        "message": "Summarize this task",
        "scope": {"type": "task", "id": "task-123"},
    }

    with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_popup_stream):
        with client.websocket_connect(f"/api/v1/ws/device?token={jwt}") as ws:
            for payload in (hello, popup):
                ws.send_text(json.dumps(payload))
            frames = _recv_until_end(ws)

    kinds = [frame["type"] for frame in frames]
    assert WsFrameType.popup_domain in kinds
    assert WsFrameType.stream_end in kinds
    assert kinds.index(WsFrameType.popup_domain) < kinds.index(WsFrameType.stream_end)

    domain_frame = next(
        frame for frame in frames if frame["type"] == WsFrameType.popup_domain
    )
    assert domain_frame["domain"] == "tasks"
    assert domain_frame["request_id"] == "p1"
|
|
|
|
|
|
def test_home_request_request_id_propagated(client):
    """The request_id sent with a home_request is echoed back unchanged.

    Every frame that carries a ``request_id`` must carry the one we sent,
    and at least one frame must carry it. Without that second condition the
    check passes vacuously when no frame echoes the id at all. Frames lacking
    the key are ignored; presumably these are connection-level frames not
    tied to this request (TODO confirm).
    """
    token = make_jwt("power", user_id=USER_ID)
    req_id = "my-unique-req-id"

    async def _stream(user_id, message, context, reg=None):
        # Stand-in for orchestrate_v3_stream: empty chunk, then one text block.
        yield "note_agent", ""
        yield "note_agent", '{"type": "text", "content": "ok"}'

    with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_stream):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-3", "agent_ids": []
            }))
            ws.send_text(json.dumps({
                "type": "home_request",
                "request_id": req_id,
                "message": "hello",
            }))
            frames = _recv_until_end(ws)

    echoed = {f["request_id"] for f in frames if "request_id" in f}
    # Non-empty AND only our id: rules out both a wrong id and no echo at all.
    assert echoed == {req_id}, f"unexpected request_id echoes {echoed!r} in {frames!r}"
|
|
|
|
|
|
def test_tool_result_dispatch_silent_on_unknown_id(client):
    """A tool_result whose id matches no pending call is ignored without error.

    The heartbeat interval is shortened to 0.05s. If the connection survived
    the stray tool_result, the next message received is a heartbeat ping.
    """
    jwt = make_jwt("power", user_id=USER_ID)
    outgoing = (
        {"type": "device_hello", "device_id": "dev-4", "agent_ids": []},
        {"type": "tool_result", "id": "no-such-id", "ok": True},
    )

    with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05):
        with client.websocket_connect(f"/api/v1/ws/device?token={jwt}") as ws:
            for payload in outgoing:
                ws.send_text(json.dumps(payload))
            # Liveness probe: a still-open socket should deliver the heartbeat
            # (a closed one would presumably make receive_text raise instead).
            reply = json.loads(ws.receive_text())
            assert reply["type"] == "ping"
|
|
|
|
|
|
def test_invalid_jwt_rejected(client):
    """A malformed token must not produce a usable device socket.

    The server may refuse the handshake, or accept it and then close. Either
    the connect or the first receive is expected to raise.
    """
    bad_url = "/api/v1/ws/device?token=badtoken"
    # NOTE(review): ``Exception`` is very broad. Any error, even a typo in
    # this test, satisfies it. Consider narrowing to the disconnect exception
    # the WebSocket test client raises.
    with pytest.raises(Exception):
        with client.websocket_connect(bad_url) as ws:
            ws.receive_text()
|