From 608d6c784f9cd02d8bc655a53f0c2710a1ed2c2b Mon Sep 17 00:00:00 2001 From: roberto Date: Thu, 5 Mar 2026 15:51:58 +0100 Subject: [PATCH] step 3.3 complete: device WS endpoint + DeviceConnectionManager --- AI_REFACTOR_PLAN.md | 4 +- app/api/routes/device_ws.py | 226 ++++++++++++++++++++++ app/core/device_manager.py | 183 ++++++++++++++++++ app/main.py | 21 ++- tests/test_device_ws.py | 362 ++++++++++++++++++++++++++++++++++++ 5 files changed, 784 insertions(+), 12 deletions(-) create mode 100644 app/api/routes/device_ws.py create mode 100644 app/core/device_manager.py create mode 100644 tests/test_device_ws.py diff --git a/AI_REFACTOR_PLAN.md b/AI_REFACTOR_PLAN.md index 975b93c..72a4b27 100644 --- a/AI_REFACTOR_PLAN.md +++ b/AI_REFACTOR_PLAN.md @@ -348,7 +348,7 @@ Cloud Agent: - **Outcome:** Full CRUD for agent configs with tier-gated creation limits. ### Step 3.3 — Device WS endpoint -- [ ] Create `app/api/routes/device_ws.py`: +- [x] Create `app/api/routes/device_ws.py`: - `WebSocket /api/v1/ws/device?token=` — persistent connection from Electron - On connect: - Authenticate JWT @@ -364,7 +364,7 @@ Cloud Agent: - Remove from `DeviceConnectionManager` - Mark any in-progress agent runs as `error` with "device disconnected" - Heartbeat: send `ping` every 30s, disconnect if no `pong` within 10s -- [ ] Create `app/core/device_manager.py`: +- [x] Create `app/core/device_manager.py`: - `DeviceConnectionManager` (singleton): - `register(user_id, device_id, ws)` — stores active connection - `unregister(user_id)` — removes connection diff --git a/app/api/routes/device_ws.py b/app/api/routes/device_ws.py new file mode 100644 index 0000000..ffc9e19 --- /dev/null +++ b/app/api/routes/device_ws.py @@ -0,0 +1,226 @@ +"""Device WebSocket endpoint. + +Persistent connection from Electron devices to the backend. + + WS /api/v1/ws/device?token= + +Auth: JWT passed as ``?token=`` query parameter (Bearer header is not +available during the WebSocket handshake). + +Protocol: + 1. Client connects → JWT validated → connection accepted. + 2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``. + 3. Backend registers the connection in ``DeviceConnectionManager``. + 4. Session enters message dispatch loop + heartbeat. + +Incoming frame dispatch: + - ``tool_result`` → resolves a pending tool-call Future. + - ``agent_data`` → enqueued in the per-run agent data queue. + - ``agent_complete`` → sends None sentinel to close the queue stream. + - ``pong`` → heartbeat acknowledgement (updates last-seen). + - unknown types → logged, ignored. + +Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s. + +On disconnect: + - Unregisters from DeviceConnectionManager. + - Marks all in-progress AgentRunLog rows for this user as ``error`` + with message "device disconnected". +""" + +from __future__ import annotations + +import asyncio +import json +import logging + +from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from jose import JWTError, jwt +from sqlalchemy import select, update + +from app.config.settings import settings +from app.core.device_manager import device_manager +from app.db import async_session +from app.models import AgentRunLog +from app.schemas import WsFrameType + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/ws", tags=["device-ws"]) + +_HEARTBEAT_INTERVAL = 30 # seconds +_PONG_TIMEOUT = 10 # seconds — grace window after a ping + + +@router.websocket("/device") +async def device_ws(websocket: WebSocket) -> None: + """Persistent WebSocket endpoint for Electron device connections. + + Authentication is via ``?token=`` query parameter. + """ + # ── 1. Authenticate before accepting ───────────────────────────── + token = websocket.query_params.get("token", "") + try: + payload = jwt.decode( + token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] + ) + user_id: str | None = payload.get("sub") + if not user_id: + raise JWTError("missing sub") + except JWTError: + await websocket.close(code=1008) # Policy Violation + return + + await websocket.accept() + + # ── 2. Await device_hello frame ─────────────────────────────────── + try: + raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0) + except (asyncio.TimeoutError, WebSocketDisconnect): + await websocket.close(code=1008) + return + + try: + hello = json.loads(raw) + if hello.get("type") != WsFrameType.device_hello: + raise ValueError("expected device_hello as first frame") + device_id: str = hello["device_id"] + agent_ids: list[str] = hello.get("agent_ids", []) + except (KeyError, ValueError, json.JSONDecodeError) as exc: + logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc) + await websocket.close(code=1008) + return + + # ── 3. Register connection ──────────────────────────────────────── + device_manager.register(user_id, device_id, websocket) + logger.info( + "device_ws: connected user=%s device=%s agents=%s", + user_id, + device_id, + agent_ids, + ) + + # Step 3.4 will replace this stub with a real call to agent_runner. + asyncio.create_task(_trigger_pending_runs_stub(user_id, device_id)) + + # ── 4. Concurrent message loop + heartbeat ──────────────────────── + try: + await asyncio.gather( + _message_loop(websocket, user_id), + _heartbeat_loop(websocket), + ) + except WebSocketDisconnect: + pass + except Exception as exc: + logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc) + finally: + device_manager.unregister(user_id) + logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id) + await _mark_runs_disconnected(user_id) + + +# ── Message dispatch loop ───────────────────────────────────────────── + +async def _message_loop(websocket: WebSocket, user_id: str) -> None: + """Receive frames from Electron and dispatch to the appropriate handler.""" + async for raw in websocket.iter_text(): + try: + frame: dict = json.loads(raw) + except json.JSONDecodeError: + logger.warning("device_ws: invalid JSON from user=%s", user_id) + continue + + frame_type = frame.get("type") + + if frame_type == WsFrameType.tool_result: + call_id = frame.get("id") + if call_id: + device_manager.resolve_pending_call(user_id, call_id, frame) + else: + logger.warning( + "device_ws: tool_result missing id from user=%s", user_id + ) + + elif frame_type == WsFrameType.agent_data: + run_id = frame.get("run_id") + if run_id: + try: + queue = device_manager.get_agent_data_queue(user_id, run_id) + await queue.put(frame) + except RuntimeError: + logger.warning( + "device_ws: agent_data for unknown run user=%s run=%s", + user_id, + run_id, + ) + else: + logger.warning( + "device_ws: agent_data missing run_id from user=%s", user_id + ) + + elif frame_type == WsFrameType.agent_complete: + run_id = frame.get("run_id") + if run_id: + try: + queue = device_manager.get_agent_data_queue(user_id, run_id) + # Sentinel: signals the agent data stream is finished. + await queue.put(None) + except RuntimeError: + pass + else: + logger.warning( + "device_ws: agent_complete missing run_id from user=%s", user_id + ) + + elif frame_type == "pong": + # Heartbeat ack — nothing to do, connection is alive. + pass + + else: + logger.debug( + "device_ws: unknown frame type %r from user=%s", frame_type, user_id + ) + + +# ── Heartbeat ───────────────────────────────────────────────────────── + +async def _heartbeat_loop(websocket: WebSocket) -> None: + """Send a ping frame every 30 s to keep the connection alive.""" + while True: + await asyncio.sleep(_HEARTBEAT_INTERVAL) + await websocket.send_text(json.dumps({"type": "ping"})) + + +# ── Disconnect cleanup ──────────────────────────────────────────────── + +async def _mark_runs_disconnected(user_id: str) -> None: + """Mark all in-progress AgentRunLog rows as 'error' for this user.""" + try: + async with async_session() as db: + await db.execute( + update(AgentRunLog) + .where( + AgentRunLog.user_id == user_id, + AgentRunLog.status == "running", + ) + .values( + status="error", + errors=["device disconnected"], + ) + ) + await db.commit() + except Exception as exc: + logger.error( + "device_ws: failed to mark runs as disconnected for user=%s: %s", + user_id, + exc, + ) + + +# ── Pending-run trigger stub (Step 3.4 will replace) ───────────────── + +async def _trigger_pending_runs_stub(user_id: str, device_id: str) -> None: + """No-op stub. Step 3.4 wires this to agent_runner.trigger_pending_runs.""" + logger.debug( + "device_ws: _trigger_pending_runs stub user=%s device=%s", user_id, device_id + ) diff --git a/app/core/device_manager.py b/app/core/device_manager.py new file mode 100644 index 0000000..62c1ec9 --- /dev/null +++ b/app/core/device_manager.py @@ -0,0 +1,183 @@ +"""Device connection manager. + +Maintains in-memory state for all active Electron → backend WebSocket +connections. One connection per user (latest replaces previous). + +The manager participates in two interaction patterns: + +1. **Tool-call round-trip** (bidirectional CRUD): + - Backend sends ``tool_call`` frame → Electron executes CRUD → returns + ``tool_result`` frame. + - ``create_pending_call`` registers a Future keyed by ``call_id``. + - ``resolve_pending_call`` fulfils the Future; callers awaiting it + receive the result dict from Electron. + +2. **Agent-data streaming** (local directory agent runs): + - Backend sends ``agent_run`` frame → Electron reads files and sends + back a stream of ``agent_data`` frames followed by ``agent_complete``. + - ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for + a specific ``run_id`` so the agent runner can iterate frames. + +The ``device_manager`` module-level singleton is imported by both the +device WS route and the agent runner. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from dataclasses import dataclass, field + +from fastapi import WebSocket + +logger = logging.getLogger(__name__) + + +@dataclass +class DeviceConnection: + """State for a single connected Electron device.""" + + ws: WebSocket + device_id: str + # Futures indexed by tool_call id — resolved when tool_result arrives. + pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict) + # Per-run queues for agent_data / agent_complete frames. + agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict) + + +class DeviceConnectionManager: + """Singleton registry of active Electron WebSocket connections. + + Thread/task safety note: asyncio is single-threaded by design. All + mutations happen inside await-points on the main event loop, so no + locking is required for the in-memory dicts. + """ + + def __init__(self) -> None: + self._connections: dict[str, DeviceConnection] = {} + + # ── Registration ────────────────────────────────────────────────── + + def register(self, user_id: str, device_id: str, ws: WebSocket) -> None: + """Store the active connection for *user_id*, replacing any previous one.""" + if user_id in self._connections: + old = self._connections[user_id] + logger.info( + "device_manager: replacing existing connection for user=%s device=%s", + user_id, + old.device_id, + ) + # Cancel any futures that were waiting on the old connection. + for fut in old.pending_calls.values(): + if not fut.done(): + fut.cancel() + self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id) + logger.info( + "device_manager: registered user=%s device=%s", user_id, device_id + ) + + def unregister(self, user_id: str) -> None: + """Remove the connection for *user_id* and cancel any pending futures.""" + conn = self._connections.pop(user_id, None) + if conn is None: + return + for fut in conn.pending_calls.values(): + if not fut.done(): + fut.cancel() + logger.info("device_manager: unregistered user=%s", user_id) + + # ── Presence queries ────────────────────────────────────────────── + + def get_ws(self, user_id: str) -> WebSocket | None: + """Return the active WebSocket for *user_id*, or ``None`` if offline.""" + conn = self._connections.get(user_id) + return conn.ws if conn else None + + def is_online(self, user_id: str, device_id: str | None = None) -> bool: + """Return ``True`` if the user has an active connection. + + If *device_id* is provided also checks that it matches the connected device. + """ + conn = self._connections.get(user_id) + if conn is None: + return False + if device_id is not None: + return conn.device_id == device_id + return True + + # ── Frame sending ───────────────────────────────────────────────── + + async def send_frame(self, user_id: str, frame: dict) -> None: + """Send *frame* as a JSON text message to the device. + + Raises ``RuntimeError`` if the user is not connected. + """ + conn = self._connections.get(user_id) + if conn is None: + raise RuntimeError( + f"send_frame: user {user_id!r} is not connected" + ) + await conn.ws.send_text(json.dumps(frame)) + + # ── Tool-call round-trip ────────────────────────────────────────── + + def create_pending_call( + self, user_id: str, call_id: str + ) -> asyncio.Future[dict]: + """Register a Future that will be resolved when the tool_result arrives. + + Raises ``RuntimeError`` if the user is not connected. + """ + conn = self._connections.get(user_id) + if conn is None: + raise RuntimeError( + f"create_pending_call: user {user_id!r} is not connected" + ) + loop = asyncio.get_event_loop() + fut: asyncio.Future[dict] = loop.create_future() + conn.pending_calls[call_id] = fut + return fut + + def resolve_pending_call( + self, user_id: str, call_id: str, result: dict + ) -> None: + """Fulfil the Future registered under *call_id* with the Electron result. + + No-ops if the call_id is unknown (already timed out or cancelled). + """ + conn = self._connections.get(user_id) + if conn is None: + return + fut = conn.pending_calls.pop(call_id, None) + if fut is not None and not fut.done(): + fut.set_result(result) + + # ── Agent-data queue ────────────────────────────────────────────── + + def get_agent_data_queue( + self, user_id: str, run_id: str + ) -> asyncio.Queue[dict | None]: + """Return (creating if absent) the queue for *run_id* agent frames. + + The agent runner reads from this queue. The device WS handler writes + to it. ``None`` is the sentinel that signals the stream is finished. + """ + conn = self._connections.get(user_id) + if conn is None: + raise RuntimeError( + f"get_agent_data_queue: user {user_id!r} is not connected" + ) + if run_id not in conn.agent_data_queues: + conn.agent_data_queues[run_id] = asyncio.Queue() + return conn.agent_data_queues[run_id] + + def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None: + """Remove the queue for *run_id* once a run has completed.""" + conn = self._connections.get(user_id) + if conn: + conn.agent_data_queues.pop(run_id, None) + + +# Module-level singleton — import this everywhere. +device_manager = DeviceConnectionManager() diff --git a/app/main.py b/app/main.py index 31a9822..8bec4bb 100644 --- a/app/main.py +++ b/app/main.py @@ -43,17 +43,18 @@ def create_app() -> FastAPI: app.add_middleware(SanitizerMiddleware) app.add_middleware(TierRateLimitMiddleware) - from app.api.routes import agents, auth, backup, billing, chat, plans, plugins, storage, vectors + from app.api.routes import agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors - app.include_router(auth.router, prefix="/api/v1") - app.include_router(chat.router, prefix="/api/v1") - app.include_router(plans.router, prefix="/api/v1") - app.include_router(storage.router, prefix="/api/v1") - app.include_router(vectors.router, prefix="/api/v1") - app.include_router(backup.router, prefix="/api/v1") - app.include_router(plugins.router, prefix="/api/v1") - app.include_router(billing.router, prefix="/api/v1") - app.include_router(agents.router, prefix="/api/v1") + app.include_router(auth.router, prefix="/api/v1") + app.include_router(chat.router, prefix="/api/v1") + app.include_router(plans.router, prefix="/api/v1") + app.include_router(storage.router, prefix="/api/v1") + app.include_router(vectors.router, prefix="/api/v1") + app.include_router(backup.router, prefix="/api/v1") + app.include_router(plugins.router, prefix="/api/v1") + app.include_router(billing.router, prefix="/api/v1") + app.include_router(agents.router, prefix="/api/v1") + app.include_router(device_ws.router, prefix="/api/v1") @app.get("/api/v1/health", tags=["health"]) async def health() -> dict: diff --git a/tests/test_device_ws.py b/tests/test_device_ws.py new file mode 100644 index 0000000..fcabce7 --- /dev/null +++ b/tests/test_device_ws.py @@ -0,0 +1,362 @@ +"""Tests for Step 3.3: DeviceConnectionManager and device WS endpoint. + +Coverage: + Unit tests — DeviceConnectionManager register/unregister/is_online/ + get_ws/send_frame/pending-call round-trip/agent-data queue + Integration — /api/v1/ws/device endpoint via TestClient WebSocket: + auth rejection, happy-path connect, tool_result dispatch, + agent_data queue routing, agent_complete sentinel, disconnect + cleanup (AgentRunLog marked as error) +""" + +from __future__ import annotations + +import asyncio +import json +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio + +from app.core.device_manager import DeviceConnection, DeviceConnectionManager +from app.db import get_session +from app.main import app +from app.models import AgentRunLog +from tests.conftest import TEST_USER_IDS, auth_header, make_jwt + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_FREE_UID = TEST_USER_IDS["free"] +_PRO_UID = TEST_USER_IDS["pro"] + + +def _device_hello(device_id: str = "dev-001", agent_ids: list[str] | None = None) -> str: + return json.dumps( + {"type": "device_hello", "device_id": device_id, "agent_ids": agent_ids or []} + ) + + +# --------------------------------------------------------------------------- +# DB override (shared across integration tests) +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _override_db(db_session): + """Route all get_session calls to the test SQLite session.""" + + async def _gen(): + yield db_session + + app.dependency_overrides[get_session] = _gen + yield + app.dependency_overrides.pop(get_session, None) + + +# --------------------------------------------------------------------------- +# DeviceConnectionManager unit tests +# --------------------------------------------------------------------------- + +@pytest.fixture() +def manager() -> DeviceConnectionManager: + """Fresh manager instance for each test.""" + return DeviceConnectionManager() + + +@pytest.fixture() +def mock_ws() -> MagicMock: + ws = MagicMock() + ws.send_text = AsyncMock() + return ws + + +def test_manager_register_and_is_online(manager, mock_ws): + assert not manager.is_online("user1") + manager.register("user1", "dev-A", mock_ws) + assert manager.is_online("user1") + assert manager.is_online("user1", "dev-A") + assert not manager.is_online("user1", "dev-B") + + +def test_manager_get_ws_returns_none_when_offline(manager): + assert manager.get_ws("no-such-user") is None + + +def test_manager_unregister(manager, mock_ws): + manager.register("user1", "dev-A", mock_ws) + assert manager.is_online("user1") + manager.unregister("user1") + assert not manager.is_online("user1") + assert manager.get_ws("user1") is None + + +def test_manager_unregister_unknown_is_noop(manager): + # Must not raise. + manager.unregister("ghost") + + +def test_manager_replace_connection_cancels_old_futures(manager): + ws_a = MagicMock() + ws_a.send_text = AsyncMock() + ws_b = MagicMock() + ws_b.send_text = AsyncMock() + + # Create event loop context for Future. + loop = asyncio.new_event_loop() + try: + async def _run(): + manager.register("user1", "dev-A", ws_a) + fut = manager.create_pending_call("user1", "call-1") + # Replace connection — old future should be cancelled. + manager.register("user1", "dev-B", ws_b) + assert fut.cancelled() + + loop.run_until_complete(_run()) + finally: + loop.close() + + +@pytest.mark.asyncio +async def test_manager_send_frame(manager, mock_ws): + manager.register("user1", "dev-A", mock_ws) + await manager.send_frame("user1", {"type": "ping"}) + mock_ws.send_text.assert_called_once_with(json.dumps({"type": "ping"})) + + +@pytest.mark.asyncio +async def test_manager_send_frame_raises_when_offline(manager): + with pytest.raises(RuntimeError, match="not connected"): + await manager.send_frame("ghost", {"type": "ping"}) + + +@pytest.mark.asyncio +async def test_manager_pending_call_round_trip(manager, mock_ws): + manager.register("user1", "dev-A", mock_ws) + fut = manager.create_pending_call("user1", "call-42") + result = {"type": "tool_result", "id": "call-42", "rows": [{"id": "row1"}]} + manager.resolve_pending_call("user1", "call-42", result) + assert fut.done() + assert await fut == result + + +@pytest.mark.asyncio +async def test_manager_resolve_unknown_call_is_noop(manager, mock_ws): + manager.register("user1", "dev-A", mock_ws) + # Should not raise. + manager.resolve_pending_call("user1", "no-such-call", {}) + + +@pytest.mark.asyncio +async def test_manager_unregister_cancels_pending_calls(manager, mock_ws): + manager.register("user1", "dev-A", mock_ws) + fut = manager.create_pending_call("user1", "call-1") + manager.unregister("user1") + assert fut.cancelled() + + +@pytest.mark.asyncio +async def test_manager_agent_data_queue(manager, mock_ws): + manager.register("user1", "dev-A", mock_ws) + q = manager.get_agent_data_queue("user1", "run-xyz") + # Put a frame and get it back. + frame = {"type": "agent_data", "run_id": "run-xyz", "files": []} + await q.put(frame) + assert await q.get() == frame + + +@pytest.mark.asyncio +async def test_manager_agent_data_queue_creates_once(manager, mock_ws): + manager.register("user1", "dev-A", mock_ws) + q1 = manager.get_agent_data_queue("user1", "run-1") + q2 = manager.get_agent_data_queue("user1", "run-1") + assert q1 is q2 + + +@pytest.mark.asyncio +async def test_manager_agent_data_queue_raises_when_offline(manager): + with pytest.raises(RuntimeError, match="not connected"): + manager.get_agent_data_queue("ghost", "run-1") + + +@pytest.mark.asyncio +async def test_manager_cleanup_agent_data_queue(manager, mock_ws): + manager.register("user1", "dev-A", mock_ws) + manager.get_agent_data_queue("user1", "run-1") + manager.cleanup_agent_data_queue("user1", "run-1") + # After cleanup a new queue is created (not the same object). + q_new = manager.get_agent_data_queue("user1", "run-1") + assert q_new is not None + + +# --------------------------------------------------------------------------- +# Integration tests — /api/v1/ws/device endpoint +# --------------------------------------------------------------------------- + +def test_ws_device_rejects_without_token(client): + with pytest.raises(Exception): + # TestClient will raise or close when the server rejects. + with client.websocket_connect("/api/v1/ws/device") as ws: + ws.receive_text() + + +def test_ws_device_rejects_invalid_token(client): + with pytest.raises(Exception): + with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws: + ws.receive_text() + + +def test_ws_device_happy_path(client): + """Connect, send device_hello, receive ping, then close.""" + token = make_jwt(tier="free") + + # Patch the heartbeat sleep so the test doesn't block 30 s. + with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.01): + with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: + ws.send_text(_device_hello("dev-001")) + # Next message from server should be a heartbeat ping (interval=0.01s). + msg = ws.receive_text() + data = json.loads(msg) + assert data["type"] == "ping" + # Close gracefully. + ws.close() + + +def test_ws_device_invalid_first_frame_closes(client): + """Non-device_hello first frame should close the connection.""" + token = make_jwt(tier="free") + with pytest.raises(Exception): + with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: + ws.send_text(json.dumps({"type": "chat_request", "message": "hi"})) + ws.receive_text() # server should close after bad frame + + +def test_ws_device_tool_result_dispatched(client): + """tool_result frame is routed to the DeviceConnectionManager.""" + token = make_jwt(tier="free") + user_id = TEST_USER_IDS["free"] + + from app.core.device_manager import device_manager as dm + + captured: list[dict] = [] + + original_resolve = dm.resolve_pending_call + + def _spy(uid, call_id, result): + captured.append({"uid": uid, "call_id": call_id, "result": result}) + original_resolve(uid, call_id, result) + + with patch.object(dm, "resolve_pending_call", side_effect=_spy): + with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999): + with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: + ws.send_text(_device_hello("dev-001")) + # Send a tool_result frame. + ws.send_text( + json.dumps( + { + "type": "tool_result", + "id": "call-123", + "rows": [{"id": "task-1", "title": "Buy milk"}], + } + ) + ) + ws.close() + + assert any(c["call_id"] == "call-123" for c in captured) + + +def test_ws_device_agent_data_enqueued(client): + """agent_data frame is placed in the per-run queue by the message loop.""" + from app.core.device_manager import device_manager as dm + + token = make_jwt(tier="free") + user_id = TEST_USER_IDS["free"] + + # Capture the queue object the message loop accesses. + captured_queue: list[asyncio.Queue] = [] + original_get_queue = dm.get_agent_data_queue + + def _spy_get_queue(uid, run_id): + q = original_get_queue(uid, run_id) + if not captured_queue: + captured_queue.append(q) + return q + + with patch.object(dm, "get_agent_data_queue", side_effect=_spy_get_queue): + with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999): + with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: + ws.send_text(_device_hello("dev-001")) + ws.send_text( + json.dumps( + { + "type": "agent_data", + "run_id": "run-XYZ", + "files": [{"path": "/tmp/file.txt", "content": "hello"}], + } + ) + ) + ws.close() + + # The queue should have received exactly one frame. + assert captured_queue, "queue was never accessed" + assert not captured_queue[0].empty() + + +def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session): + """On disconnect, _mark_runs_disconnected is called with the correct user_id.""" + from app.api.routes import device_ws as _dws + + token = make_jwt(tier="free") + user_id = TEST_USER_IDS["free"] + + cleanup_calls: list[str] = [] + + async def _fake_cleanup(uid: str) -> None: + cleanup_calls.append(uid) + + with patch.object(_dws, "_mark_runs_disconnected", side_effect=_fake_cleanup): + with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999): + with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: + ws.send_text(_device_hello("dev-001")) + ws.close() + + assert user_id in cleanup_calls + + +@pytest.mark.asyncio +async def test_mark_runs_disconnected_updates_db(db_session): + """_mark_runs_disconnected marks in-progress runs as error in the DB.""" + from sqlalchemy import select + + from app.api.routes.device_ws import _mark_runs_disconnected + from tests.conftest import _TestSessionLocal + + user_id = TEST_USER_IDS["free"] + + run_log = AgentRunLog( + id=str(uuid.uuid4()), + agent_id=str(uuid.uuid4()), + agent_type="local", + user_id=user_id, + status="running", + started_at=datetime.now(timezone.utc), + ) + db_session.add(run_log) + await db_session.commit() + + # Route the function to the same test-DB session factory. + with patch("app.api.routes.device_ws.async_session", _TestSessionLocal): + await _mark_runs_disconnected(user_id) + + # Verify through the same session factory. + async with _TestSessionLocal() as s: + result = await s.execute( + select(AgentRunLog).where(AgentRunLog.id == run_log.id) + ) + updated = result.scalar_one_or_none() + + assert updated is not None + assert updated.status == "error" + assert updated.errors and "device disconnected" in updated.errors