"""Integration tests for the unified WebSocket handler (Step 5). Tests the device WS endpoint with home_request and popup_request frames, verifying that the correct v3 frame sequence is returned. LLM calls are mocked to avoid network dependency. """ from __future__ import annotations import json from unittest.mock import patch import pytest from app.db import get_session from app.main import app from app.schemas import WsFrameType from tests.conftest import TEST_USER_IDS, make_jwt USER_ID = TEST_USER_IDS["power"] # ── helpers ─────────────────────────────────────────────────────────────────── @pytest.fixture(autouse=True) def _override_db(db_session): async def _gen(): yield db_session app.dependency_overrides[get_session] = _gen yield app.dependency_overrides.pop(get_session, None) def _recv_until_end(ws, max_frames: int = 20) -> list[dict]: """Receive frames until stream_end (or stream_end inside popup flow), or max_frames.""" frames = [] for _ in range(max_frames): raw = ws.receive_text() frame = json.loads(raw) frames.append(frame) if frame.get("type") == WsFrameType.stream_end: break return frames async def _mock_home_stream(user_id, message, context, reg=None): yield "task_agent", "" yield "task_agent", '{"type": "text", "content": "Hello"}' async def _mock_popup_stream(user_id, message, context, reg=None): yield "task_agent", "" yield "task_agent", "Here is a summary" # ── tests ───────────────────────────────────────────────────────────────────── def test_home_request_produces_stream_frames(client): """home_request → stream_start, stream_text+, stream_end.""" token = make_jwt("power", user_id=USER_ID) with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_home_stream): with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: ws.send_text(json.dumps({ "type": "device_hello", "device_id": "dev-1", "agent_ids": [] })) ws.send_text(json.dumps({ "type": "home_request", "request_id": "r1", "message": "List my tasks", "conversation_history": [], })) frames = _recv_until_end(ws) types = [f["type"] for f in frames] assert WsFrameType.stream_start in types assert WsFrameType.stream_end in types assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end) def test_popup_request_produces_domain_frame(client): """popup_request → popup_domain first, then stream_text*, stream_end.""" token = make_jwt("power", user_id=USER_ID) with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_popup_stream): with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: ws.send_text(json.dumps({ "type": "device_hello", "device_id": "dev-2", "agent_ids": [] })) ws.send_text(json.dumps({ "type": "popup_request", "request_id": "p1", "message": "Summarize this task", "scope": {"type": "task", "id": "task-123"}, })) frames = _recv_until_end(ws) types = [f["type"] for f in frames] assert WsFrameType.popup_domain in types assert WsFrameType.stream_end in types assert types.index(WsFrameType.popup_domain) < types.index(WsFrameType.stream_end) domain_frame = next(f for f in frames if f["type"] == WsFrameType.popup_domain) assert domain_frame["domain"] == "tasks" assert domain_frame["request_id"] == "p1" def test_home_request_request_id_propagated(client): """request_id in home_request is echoed in all response frames.""" token = make_jwt("power", user_id=USER_ID) req_id = "my-unique-req-id" async def _stream(user_id, message, context, reg=None): yield "note_agent", "" yield "note_agent", '{"type": "text", "content": "ok"}' with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_stream): with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: ws.send_text(json.dumps({ "type": "device_hello", "device_id": "dev-3", "agent_ids": [] })) ws.send_text(json.dumps({ "type": "home_request", "request_id": req_id, "message": "hello", })) frames = _recv_until_end(ws) for f in frames: if "request_id" in f: assert f["request_id"] == req_id def test_tool_result_dispatch_silent_on_unknown_id(client): """tool_result for unknown call_id is silently ignored — no crash.""" token = make_jwt("power", user_id=USER_ID) with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05): with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: ws.send_text(json.dumps({ "type": "device_hello", "device_id": "dev-4", "agent_ids": [] })) ws.send_text(json.dumps({ "type": "tool_result", "id": "no-such-id", "ok": True })) # If connection is still alive, we'll get the heartbeat ping msg = json.loads(ws.receive_text()) assert msg["type"] == "ping" def test_invalid_jwt_rejected(client): """Connection with bad token is closed before or after accept.""" with pytest.raises(Exception): with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws: ws.receive_text()