"""Tests for Step 3.3: DeviceConnectionManager and device WS endpoint. Coverage: Unit tests — DeviceConnectionManager register/unregister/is_online/ get_ws/send_frame/pending-call round-trip/agent-data queue Integration — /api/v1/ws/device endpoint via TestClient WebSocket: auth rejection, happy-path connect, tool_result dispatch, agent_data queue routing, agent_complete sentinel, disconnect cleanup (AgentRunLog marked as error) """ from __future__ import annotations import asyncio import json import uuid from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.core.device_manager import DeviceConnectionManager from app.db import get_session from app.main import app from app.models import ScoutRunLog from tests.conftest import TEST_USER_IDS, make_jwt # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- _FREE_UID = TEST_USER_IDS["free"] _PRO_UID = TEST_USER_IDS["pro"] def _device_hello(device_id: str = "dev-001", scout_ids: list[str] | None = None) -> str: return json.dumps( {"type": "device_hello", "device_id": device_id, "scout_ids": scout_ids or []} ) # --------------------------------------------------------------------------- # DB override (shared across integration tests) # --------------------------------------------------------------------------- @pytest.fixture(autouse=True) def _override_db(db_session): """Route all get_session calls to the test SQLite session.""" async def _gen(): yield db_session app.dependency_overrides[get_session] = _gen yield app.dependency_overrides.pop(get_session, None) # --------------------------------------------------------------------------- # DeviceConnectionManager unit tests # --------------------------------------------------------------------------- @pytest.fixture() def manager() -> DeviceConnectionManager: """Fresh manager instance for each test.""" return DeviceConnectionManager() @pytest.fixture() def mock_ws() -> MagicMock: ws = MagicMock() ws.send_text = AsyncMock() return ws def test_manager_register_and_is_online(manager, mock_ws): assert not manager.is_online("user1") manager.register("user1", "dev-A", mock_ws) assert manager.is_online("user1") assert manager.is_online("user1", "dev-A") assert not manager.is_online("user1", "dev-B") def test_manager_get_ws_returns_none_when_offline(manager): assert manager.get_ws("no-such-user") is None def test_manager_unregister(manager, mock_ws): manager.register("user1", "dev-A", mock_ws) assert manager.is_online("user1") manager.unregister("user1") assert not manager.is_online("user1") assert manager.get_ws("user1") is None def test_manager_unregister_unknown_is_noop(manager): # Must not raise. manager.unregister("ghost") def test_manager_replace_connection_cancels_old_futures(manager): ws_a = MagicMock() ws_a.send_text = AsyncMock() ws_b = MagicMock() ws_b.send_text = AsyncMock() # Create event loop context for Future. loop = asyncio.new_event_loop() try: async def _run(): manager.register("user1", "dev-A", ws_a) fut = manager.create_pending_call("user1", "call-1") # Replace connection — old future should be cancelled. manager.register("user1", "dev-B", ws_b) assert fut.cancelled() loop.run_until_complete(_run()) finally: loop.close() @pytest.mark.asyncio async def test_manager_send_frame(manager, mock_ws): manager.register("user1", "dev-A", mock_ws) await manager.send_frame("user1", {"type": "ping"}) mock_ws.send_text.assert_called_once_with(json.dumps({"type": "ping"})) @pytest.mark.asyncio async def test_manager_send_frame_raises_when_offline(manager): with pytest.raises(RuntimeError, match="not connected"): await manager.send_frame("ghost", {"type": "ping"}) @pytest.mark.asyncio async def test_manager_pending_call_round_trip(manager, mock_ws): manager.register("user1", "dev-A", mock_ws) fut = manager.create_pending_call("user1", "call-42") result = {"type": "tool_result", "id": "call-42", "rows": [{"id": "row1"}]} manager.resolve_pending_call("user1", "call-42", result) assert fut.done() assert await fut == result @pytest.mark.asyncio async def test_manager_resolve_unknown_call_is_noop(manager, mock_ws): manager.register("user1", "dev-A", mock_ws) # Should not raise. manager.resolve_pending_call("user1", "no-such-call", {}) @pytest.mark.asyncio async def test_manager_unregister_cancels_pending_calls(manager, mock_ws): manager.register("user1", "dev-A", mock_ws) fut = manager.create_pending_call("user1", "call-1") manager.unregister("user1") assert fut.cancelled() # --------------------------------------------------------------------------- # Integration tests — /api/v1/ws/device endpoint # --------------------------------------------------------------------------- def test_ws_device_rejects_without_token(client): with pytest.raises(Exception): # TestClient will raise or close when the server rejects. with client.websocket_connect("/api/v1/ws/device") as ws: ws.receive_text() def test_ws_device_rejects_invalid_token(client): with pytest.raises(Exception): with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws: ws.receive_text() def test_ws_device_happy_path(client): """Connect, send device_hello, receive ping, then close.""" token = make_jwt(tier="free") # Patch the heartbeat sleep so the test doesn't block 30 s. with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.01): with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: ws.send_text(_device_hello("dev-001")) # Next message from server should be a heartbeat ping (interval=0.01s). msg = ws.receive_text() data = json.loads(msg) assert data["type"] == "ping" # Close gracefully. ws.close() def test_ws_device_invalid_first_frame_closes(client): """Non-device_hello first frame should close the connection.""" token = make_jwt(tier="free") with pytest.raises(Exception): with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: ws.send_text(json.dumps({"type": "chat_request", "message": "hi"})) ws.receive_text() # server should close after bad frame def test_ws_device_tool_result_dispatched(client): """tool_result frame is routed to the DeviceConnectionManager.""" token = make_jwt(tier="free") from app.core.device_manager import device_manager as dm captured: list[dict] = [] original_resolve = dm.resolve_pending_call def _spy(uid, call_id, result): captured.append({"uid": uid, "call_id": call_id, "result": result}) original_resolve(uid, call_id, result) with patch.object(dm, "resolve_pending_call", side_effect=_spy): with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999): with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: ws.send_text(_device_hello("dev-001")) # Send a tool_result frame. ws.send_text( json.dumps( { "type": "tool_result", "id": "call-123", "rows": [{"id": "task-1", "title": "Buy milk"}], } ) ) ws.close() assert any(c["call_id"] == "call-123" for c in captured) def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session): """On disconnect, _mark_runs_disconnected is called with the correct user_id.""" from app.api.routes import device_ws as _dws token = make_jwt(tier="free") user_id = TEST_USER_IDS["free"] cleanup_calls: list[str] = [] async def _fake_cleanup(uid: str) -> None: cleanup_calls.append(uid) with patch.object(_dws, "_mark_runs_disconnected", side_effect=_fake_cleanup): with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999): with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: ws.send_text(_device_hello("dev-001")) ws.close() assert user_id in cleanup_calls @pytest.mark.asyncio async def test_mark_runs_disconnected_updates_db(db_session): """_mark_runs_disconnected marks in-progress runs as error in the DB.""" from sqlalchemy import select from app.api.routes.device_ws import _mark_runs_disconnected from tests.conftest import _TestSessionLocal user_id = TEST_USER_IDS["free"] run_log = ScoutRunLog( id=str(uuid.uuid4()), scout_id=str(uuid.uuid4()), scout_type="local", user_id=user_id, status="running", started_at=datetime.now(timezone.utc), ) db_session.add(run_log) await db_session.commit() # Route the function to the same test-DB session factory. with patch("app.api.routes.device_ws.async_session", _TestSessionLocal): await _mark_runs_disconnected(user_id) # Verify through the same session factory. async with _TestSessionLocal() as s: result = await s.execute( select(ScoutRunLog).where(ScoutRunLog.id == run_log.id) ) updated = result.scalar_one_or_none() assert updated is not None assert updated.status == "error" assert updated.errors and "device disconnected" in updated.errors