WsDeviceHello.agent_ids → scout_ids in Pydantic schema, device_ws.py handler, and all test fixtures (test_device_ws, test_ws_unified, test_memory_middleware). Also fixes stale CloudAgentConfig reference in gmail.py docstring. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
290 lines
9.7 KiB
Python
290 lines
9.7 KiB
Python
"""Tests for Step 3.3: DeviceConnectionManager and device WS endpoint.
|
|
|
|
Coverage:
    Unit tests — DeviceConnectionManager register/unregister/is_online/
        get_ws/send_frame/pending-call round-trip/connection replacement
    Integration — /api/v1/ws/device endpoint via TestClient WebSocket:
        auth rejection, happy-path connect, invalid first frame,
        tool_result dispatch, disconnect cleanup (ScoutRunLog marked as error)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.core.device_manager import DeviceConnectionManager
|
|
from app.db import get_session
|
|
from app.main import app
|
|
from app.models import ScoutRunLog
|
|
from tests.conftest import TEST_USER_IDS, make_jwt
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Shorthand for the shared conftest user ids, looked up by tier name.
_FREE_UID, _PRO_UID = TEST_USER_IDS["free"], TEST_USER_IDS["pro"]
|
|
|
|
|
|
def _device_hello(device_id: str = "dev-001", scout_ids: list[str] | None = None) -> str:
|
|
return json.dumps(
|
|
{"type": "device_hello", "device_id": device_id, "scout_ids": scout_ids or []}
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DB override (shared across integration tests)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point every ``get_session`` dependency at the test database session.

    Applied automatically to each test in this module; the override is
    removed again on teardown.
    """

    async def _session_override():
        yield db_session

    app.dependency_overrides[get_session] = _session_override
    yield
    # Teardown: drop our override (tolerating it already being gone).
    app.dependency_overrides.pop(get_session, None)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DeviceConnectionManager unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.fixture()
def manager() -> DeviceConnectionManager:
    """Provide an isolated DeviceConnectionManager so tests never share state."""
    fresh_manager = DeviceConnectionManager()
    return fresh_manager
|
|
|
|
|
|
@pytest.fixture()
def mock_ws() -> MagicMock:
    """Stand-in WebSocket whose ``send_text`` is awaitable and inspectable."""
    return MagicMock(send_text=AsyncMock())
|
|
|
|
|
|
def test_manager_register_and_is_online(manager, mock_ws):
    user, device = "user1", "dev-A"
    assert not manager.is_online(user)

    manager.register(user, device, mock_ws)

    # Online for the user overall and for the registered device only.
    assert manager.is_online(user)
    assert manager.is_online(user, device)
    assert not manager.is_online(user, "dev-B")
|
|
|
|
|
|
def test_manager_get_ws_returns_none_when_offline(manager):
    # A user that never registered has no socket to hand back.
    socket = manager.get_ws("no-such-user")
    assert socket is None
|
|
|
|
|
|
def test_manager_unregister(manager, mock_ws):
    uid = "user1"
    manager.register(uid, "dev-A", mock_ws)
    assert manager.is_online(uid)

    manager.unregister(uid)

    # Both lookups must forget the user once unregistered.
    assert not manager.is_online(uid)
    assert manager.get_ws(uid) is None
|
|
|
|
|
|
def test_manager_unregister_unknown_is_noop(manager):
    """Unregistering a user that never connected is silently ignored."""
    manager.unregister("ghost")  # must not raise
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_manager_replace_connection_cancels_old_futures(manager):
    """Re-registering a user cancels calls pending on the previous socket.

    Runs under ``@pytest.mark.asyncio`` like the other async manager tests in
    this module. An event loop is therefore already running when the manager
    creates its pending-call future, so there is no need for the hand-managed
    ``asyncio.new_event_loop()`` / ``run_until_complete`` / ``close`` dance.
    """
    ws_a = MagicMock(send_text=AsyncMock())
    ws_b = MagicMock(send_text=AsyncMock())

    manager.register("user1", "dev-A", ws_a)
    fut = manager.create_pending_call("user1", "call-1")

    # Replace connection — old future should be cancelled.
    manager.register("user1", "dev-B", ws_b)

    assert fut.cancelled()
    # The replacement device is the one now registered for the user.
    assert manager.is_online("user1", "dev-B")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_manager_send_frame(manager, mock_ws):
    manager.register("user1", "dev-A", mock_ws)

    await manager.send_frame("user1", {"type": "ping"})

    # The frame reaches the socket once, serialised with json.dumps.
    expected_text = json.dumps({"type": "ping"})
    mock_ws.send_text.assert_called_once_with(expected_text)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_manager_send_frame_raises_when_offline(manager):
    """Sending to a user with no registered socket raises RuntimeError."""
    ping = {"type": "ping"}
    with pytest.raises(RuntimeError, match="not connected"):
        await manager.send_frame("ghost", ping)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_manager_pending_call_round_trip(manager, mock_ws):
    call_id = "call-42"
    manager.register("user1", "dev-A", mock_ws)
    fut = manager.create_pending_call("user1", call_id)

    payload = {"type": "tool_result", "id": call_id, "rows": [{"id": "row1"}]}
    manager.resolve_pending_call("user1", call_id, payload)

    # Resolution completes the future immediately, before it is awaited.
    assert fut.done()
    assert await fut == payload
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_manager_resolve_unknown_call_is_noop(manager, mock_ws):
    """Resolving an id with no pending call is ignored rather than raising."""
    manager.register("user1", "dev-A", mock_ws)
    manager.resolve_pending_call("user1", "no-such-call", {})
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_manager_unregister_cancels_pending_calls(manager, mock_ws):
    """Dropping a user cancels a call that is still awaiting its result."""
    manager.register("user1", "dev-A", mock_ws)
    pending = manager.create_pending_call("user1", "call-1")

    manager.unregister("user1")

    assert pending.cancelled()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration tests — /api/v1/ws/device endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_ws_device_rejects_without_token(client):
    """Connecting without a ``token`` query parameter must fail."""
    url = "/api/v1/ws/device"
    # The TestClient raises once the server rejects or closes the socket.
    with pytest.raises(Exception):
        with client.websocket_connect(url) as ws:
            ws.receive_text()
|
|
|
|
|
|
def test_ws_device_rejects_invalid_token(client):
    """A garbage ``token`` query parameter must be rejected."""
    bad_url = "/api/v1/ws/device?token=badtoken"
    with pytest.raises(Exception):
        with client.websocket_connect(bad_url) as ws:
            ws.receive_text()
|
|
|
|
|
|
def test_ws_device_happy_path(client):
    """Connect, send device_hello, receive ping, then close."""
    token = make_jwt(tier="free")
    url = f"/api/v1/ws/device?token={token}"

    # Shrink the heartbeat interval so the first ping arrives almost at once
    # rather than after the 30 s production interval.
    heartbeat = patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.01)
    with heartbeat, client.websocket_connect(url) as ws:
        ws.send_text(_device_hello("dev-001"))
        first_frame = json.loads(ws.receive_text())
        assert first_frame["type"] == "ping"
        ws.close()
|
|
|
|
|
|
def test_ws_device_invalid_first_frame_closes(client):
    """Non-device_hello first frame should close the connection."""
    token = make_jwt(tier="free")
    wrong_first_frame = json.dumps({"type": "chat_request", "message": "hi"})
    with pytest.raises(Exception):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(wrong_first_frame)
            ws.receive_text()  # the server closes after the bad frame
|
|
|
|
|
|
def test_ws_device_tool_result_dispatched(client):
    """tool_result frame is routed to the DeviceConnectionManager for the caller."""
    token = make_jwt(tier="free")

    from app.core.device_manager import device_manager as dm

    captured: list[dict] = []

    original_resolve = dm.resolve_pending_call

    def _spy(uid, call_id, result):
        # Record the call, then forward it so the real manager still runs.
        captured.append({"uid": uid, "call_id": call_id, "result": result})
        original_resolve(uid, call_id, result)

    with patch.object(dm, "resolve_pending_call", side_effect=_spy):
        # A huge heartbeat interval keeps pings out of this exchange.
        with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
            with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
                ws.send_text(_device_hello("dev-001"))
                # Send a tool_result frame.
                ws.send_text(
                    json.dumps(
                        {
                            "type": "tool_result",
                            "id": "call-123",
                            "rows": [{"id": "task-1", "title": "Buy milk"}],
                        }
                    )
                )
                ws.close()

    # The frame must be resolved for the authenticated (free-tier) user,
    # not merely observed with the right call id.
    assert any(
        c["call_id"] == "call-123" and c["uid"] == TEST_USER_IDS["free"]
        for c in captured
    )
|
|
|
|
|
|
def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session):
    """On disconnect, _mark_runs_disconnected is called with the correct user_id."""
    from app.api.routes import device_ws as _dws

    token = make_jwt(tier="free")
    expected_uid = TEST_USER_IDS["free"]
    seen_uids: list[str] = []

    async def _record_cleanup(uid: str) -> None:
        # Stand-in for the DB cleanup: only note which user it ran for.
        seen_uids.append(uid)

    cleanup_patch = patch.object(_dws, "_mark_runs_disconnected", side_effect=_record_cleanup)
    heartbeat_patch = patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999)
    with cleanup_patch, heartbeat_patch:
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(_device_hello("dev-001"))
            ws.close()

    assert expected_uid in seen_uids
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mark_runs_disconnected_updates_db(db_session):
    """_mark_runs_disconnected marks in-progress runs as error in the DB."""
    from sqlalchemy import select

    from app.api.routes.device_ws import _mark_runs_disconnected
    from tests.conftest import _TestSessionLocal

    user_id = TEST_USER_IDS["free"]
    # Keep the primary key in a plain local instead of reading ``run_log.id``
    # after the commit. Unless the test session sets expire_on_commit=False,
    # commit expires the instance's attributes. Reading an expired attribute on
    # an AsyncSession-bound object then triggers implicit (unsupported) IO.
    run_id = str(uuid.uuid4())

    run_log = ScoutRunLog(
        id=run_id,
        scout_id=str(uuid.uuid4()),
        scout_type="local",
        user_id=user_id,
        status="running",
        started_at=datetime.now(timezone.utc),
    )
    db_session.add(run_log)
    await db_session.commit()

    # Route the function to the same test-DB session factory.
    with patch("app.api.routes.device_ws.async_session", _TestSessionLocal):
        await _mark_runs_disconnected(user_id)

    # Verify through the same session factory.
    async with _TestSessionLocal() as s:
        result = await s.execute(select(ScoutRunLog).where(ScoutRunLog.id == run_id))
        updated = result.scalar_one_or_none()

    assert updated is not None
    assert updated.status == "error"
    assert updated.errors and "device disconnected" in updated.errors
|