step 3.3 complete: device WS endpoint + DeviceConnectionManager

This commit is contained in:
2026-03-05 15:51:58 +01:00
parent 19ad5be97f
commit 608d6c784f
5 changed files with 784 additions and 12 deletions

View File

@@ -348,7 +348,7 @@ Cloud Agent:
- **Outcome:** Full CRUD for agent configs with tier-gated creation limits. - **Outcome:** Full CRUD for agent configs with tier-gated creation limits.
### Step 3.3 — Device WS endpoint ### Step 3.3 — Device WS endpoint
- [ ] Create `app/api/routes/device_ws.py`: - [x] Create `app/api/routes/device_ws.py`:
- `WebSocket /api/v1/ws/device?token=<jwt>` — persistent connection from Electron - `WebSocket /api/v1/ws/device?token=<jwt>` — persistent connection from Electron
- On connect: - On connect:
- Authenticate JWT - Authenticate JWT
@@ -364,7 +364,7 @@ Cloud Agent:
- Remove from `DeviceConnectionManager` - Remove from `DeviceConnectionManager`
- Mark any in-progress agent runs as `error` with "device disconnected" - Mark any in-progress agent runs as `error` with "device disconnected"
- Heartbeat: send `ping` every 30s, disconnect if no `pong` within 10s - Heartbeat: send `ping` every 30s, disconnect if no `pong` within 10s
- [ ] Create `app/core/device_manager.py`: - [x] Create `app/core/device_manager.py`:
- `DeviceConnectionManager` (singleton): - `DeviceConnectionManager` (singleton):
- `register(user_id, device_id, ws)` — stores active connection - `register(user_id, device_id, ws)` — stores active connection
- `unregister(user_id)` — removes connection - `unregister(user_id)` — removes connection

226
app/api/routes/device_ws.py Normal file
View File

@@ -0,0 +1,226 @@
"""Device WebSocket endpoint.
Persistent connection from Electron devices to the backend.
WS /api/v1/ws/device?token=<jwt>
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
available during the WebSocket handshake).
Protocol:
1. Client connects → JWT validated → connection accepted.
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
3. Backend registers the connection in ``DeviceConnectionManager``.
4. Session enters message dispatch loop + heartbeat.
Incoming frame dispatch:
- ``tool_result`` → resolves a pending tool-call Future.
- ``agent_data`` → enqueued in the per-run agent data queue.
- ``agent_complete`` → sends None sentinel to close the queue stream.
- ``pong`` → heartbeat acknowledgement (updates last-seen).
- unknown types → logged, ignored.
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
On disconnect:
- Unregisters from DeviceConnectionManager.
- Marks all in-progress AgentRunLog rows for this user as ``error``
with message "device disconnected".
"""
from __future__ import annotations
import asyncio
import json
import logging
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from sqlalchemy import select, update
from app.config.settings import settings
from app.core.device_manager import device_manager
from app.db import async_session
from app.models import AgentRunLog
from app.schemas import WsFrameType
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["device-ws"])
_HEARTBEAT_INTERVAL = 30 # seconds
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
"""Persistent WebSocket endpoint for Electron device connections.
Authentication is via ``?token=<jwt>`` query parameter.
"""
# ── 1. Authenticate before accepting ─────────────────────────────
token = websocket.query_params.get("token", "")
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str | None = payload.get("sub")
if not user_id:
raise JWTError("missing sub")
except JWTError:
await websocket.close(code=1008) # Policy Violation
return
await websocket.accept()
# ── 2. Await device_hello frame ───────────────────────────────────
try:
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
except (asyncio.TimeoutError, WebSocketDisconnect):
await websocket.close(code=1008)
return
try:
hello = json.loads(raw)
if hello.get("type") != WsFrameType.device_hello:
raise ValueError("expected device_hello as first frame")
device_id: str = hello["device_id"]
agent_ids: list[str] = hello.get("agent_ids", [])
except (KeyError, ValueError, json.JSONDecodeError) as exc:
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
await websocket.close(code=1008)
return
# ── 3. Register connection ────────────────────────────────────────
device_manager.register(user_id, device_id, websocket)
logger.info(
"device_ws: connected user=%s device=%s agents=%s",
user_id,
device_id,
agent_ids,
)
# Step 3.4 will replace this stub with a real call to agent_runner.
asyncio.create_task(_trigger_pending_runs_stub(user_id, device_id))
# ── 4. Concurrent message loop + heartbeat ────────────────────────
try:
await asyncio.gather(
_message_loop(websocket, user_id),
_heartbeat_loop(websocket),
)
except WebSocketDisconnect:
pass
except Exception as exc:
logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
finally:
device_manager.unregister(user_id)
logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
await _mark_runs_disconnected(user_id)
# ── Message dispatch loop ─────────────────────────────────────────────
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
"""Receive frames from Electron and dispatch to the appropriate handler."""
async for raw in websocket.iter_text():
try:
frame: dict = json.loads(raw)
except json.JSONDecodeError:
logger.warning("device_ws: invalid JSON from user=%s", user_id)
continue
frame_type = frame.get("type")
if frame_type == WsFrameType.tool_result:
call_id = frame.get("id")
if call_id:
device_manager.resolve_pending_call(user_id, call_id, frame)
else:
logger.warning(
"device_ws: tool_result missing id from user=%s", user_id
)
elif frame_type == WsFrameType.agent_data:
run_id = frame.get("run_id")
if run_id:
try:
queue = device_manager.get_agent_data_queue(user_id, run_id)
await queue.put(frame)
except RuntimeError:
logger.warning(
"device_ws: agent_data for unknown run user=%s run=%s",
user_id,
run_id,
)
else:
logger.warning(
"device_ws: agent_data missing run_id from user=%s", user_id
)
elif frame_type == WsFrameType.agent_complete:
run_id = frame.get("run_id")
if run_id:
try:
queue = device_manager.get_agent_data_queue(user_id, run_id)
# Sentinel: signals the agent data stream is finished.
await queue.put(None)
except RuntimeError:
pass
else:
logger.warning(
"device_ws: agent_complete missing run_id from user=%s", user_id
)
elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive.
pass
else:
logger.debug(
"device_ws: unknown frame type %r from user=%s", frame_type, user_id
)
# ── Heartbeat ─────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None:
"""Send a ping frame every 30 s to keep the connection alive."""
while True:
await asyncio.sleep(_HEARTBEAT_INTERVAL)
await websocket.send_text(json.dumps({"type": "ping"}))
# ── Disconnect cleanup ────────────────────────────────────────────────
async def _mark_runs_disconnected(user_id: str) -> None:
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
try:
async with async_session() as db:
await db.execute(
update(AgentRunLog)
.where(
AgentRunLog.user_id == user_id,
AgentRunLog.status == "running",
)
.values(
status="error",
errors=["device disconnected"],
)
)
await db.commit()
except Exception as exc:
logger.error(
"device_ws: failed to mark runs as disconnected for user=%s: %s",
user_id,
exc,
)
# ── Pending-run trigger stub (Step 3.4 will replace) ─────────────────
async def _trigger_pending_runs_stub(user_id: str, device_id: str) -> None:
"""No-op stub. Step 3.4 wires this to agent_runner.trigger_pending_runs."""
logger.debug(
"device_ws: _trigger_pending_runs stub user=%s device=%s", user_id, device_id
)

183
app/core/device_manager.py Normal file
View File

@@ -0,0 +1,183 @@
"""Device connection manager.
Maintains in-memory state for all active Electron → backend WebSocket
connections. One connection per user (latest replaces previous).
The manager participates in two interaction patterns:
1. **Tool-call round-trip** (bidirectional CRUD):
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
``tool_result`` frame.
- ``create_pending_call`` registers a Future keyed by ``call_id``.
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
receive the result dict from Electron.
2. **Agent-data streaming** (local directory agent runs):
- Backend sends ``agent_run`` frame → Electron reads files and sends
back a stream of ``agent_data`` frames followed by ``agent_complete``.
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
a specific ``run_id`` so the agent runner can iterate frames.
The ``device_manager`` module-level singleton is imported by both the
device WS route and the agent runner.
"""
from __future__ import annotations
import asyncio
import json
import logging
from dataclasses import dataclass, field
from fastapi import WebSocket
logger = logging.getLogger(__name__)
@dataclass
class DeviceConnection:
"""State for a single connected Electron device."""
ws: WebSocket
device_id: str
# Futures indexed by tool_call id — resolved when tool_result arrives.
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
# Per-run queues for agent_data / agent_complete frames.
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
class DeviceConnectionManager:
"""Singleton registry of active Electron WebSocket connections.
Thread/task safety note: asyncio is single-threaded by design. All
mutations happen inside await-points on the main event loop, so no
locking is required for the in-memory dicts.
"""
def __init__(self) -> None:
self._connections: dict[str, DeviceConnection] = {}
# ── Registration ──────────────────────────────────────────────────
def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
"""Store the active connection for *user_id*, replacing any previous one."""
if user_id in self._connections:
old = self._connections[user_id]
logger.info(
"device_manager: replacing existing connection for user=%s device=%s",
user_id,
old.device_id,
)
# Cancel any futures that were waiting on the old connection.
for fut in old.pending_calls.values():
if not fut.done():
fut.cancel()
self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
logger.info(
"device_manager: registered user=%s device=%s", user_id, device_id
)
def unregister(self, user_id: str) -> None:
"""Remove the connection for *user_id* and cancel any pending futures."""
conn = self._connections.pop(user_id, None)
if conn is None:
return
for fut in conn.pending_calls.values():
if not fut.done():
fut.cancel()
logger.info("device_manager: unregistered user=%s", user_id)
# ── Presence queries ──────────────────────────────────────────────
def get_ws(self, user_id: str) -> WebSocket | None:
"""Return the active WebSocket for *user_id*, or ``None`` if offline."""
conn = self._connections.get(user_id)
return conn.ws if conn else None
def is_online(self, user_id: str, device_id: str | None = None) -> bool:
"""Return ``True`` if the user has an active connection.
If *device_id* is provided also checks that it matches the connected device.
"""
conn = self._connections.get(user_id)
if conn is None:
return False
if device_id is not None:
return conn.device_id == device_id
return True
# ── Frame sending ─────────────────────────────────────────────────
async def send_frame(self, user_id: str, frame: dict) -> None:
"""Send *frame* as a JSON text message to the device.
Raises ``RuntimeError`` if the user is not connected.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"send_frame: user {user_id!r} is not connected"
)
await conn.ws.send_text(json.dumps(frame))
# ── Tool-call round-trip ──────────────────────────────────────────
def create_pending_call(
self, user_id: str, call_id: str
) -> asyncio.Future[dict]:
"""Register a Future that will be resolved when the tool_result arrives.
Raises ``RuntimeError`` if the user is not connected.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"create_pending_call: user {user_id!r} is not connected"
)
loop = asyncio.get_event_loop()
fut: asyncio.Future[dict] = loop.create_future()
conn.pending_calls[call_id] = fut
return fut
def resolve_pending_call(
self, user_id: str, call_id: str, result: dict
) -> None:
"""Fulfil the Future registered under *call_id* with the Electron result.
No-ops if the call_id is unknown (already timed out or cancelled).
"""
conn = self._connections.get(user_id)
if conn is None:
return
fut = conn.pending_calls.pop(call_id, None)
if fut is not None and not fut.done():
fut.set_result(result)
# ── Agent-data queue ──────────────────────────────────────────────
def get_agent_data_queue(
self, user_id: str, run_id: str
) -> asyncio.Queue[dict | None]:
"""Return (creating if absent) the queue for *run_id* agent frames.
The agent runner reads from this queue. The device WS handler writes
to it. ``None`` is the sentinel that signals the stream is finished.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"get_agent_data_queue: user {user_id!r} is not connected"
)
if run_id not in conn.agent_data_queues:
conn.agent_data_queues[run_id] = asyncio.Queue()
return conn.agent_data_queues[run_id]
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
"""Remove the queue for *run_id* once a run has completed."""
conn = self._connections.get(user_id)
if conn:
conn.agent_data_queues.pop(run_id, None)
# Module-level singleton — import this everywhere.
device_manager = DeviceConnectionManager()

View File

@@ -43,17 +43,18 @@ def create_app() -> FastAPI:
app.add_middleware(SanitizerMiddleware) app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware) app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agents, auth, backup, billing, chat, plans, plugins, storage, vectors from app.api.routes import agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors
app.include_router(auth.router, prefix="/api/v1") app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1") app.include_router(chat.router, prefix="/api/v1")
app.include_router(plans.router, prefix="/api/v1") app.include_router(plans.router, prefix="/api/v1")
app.include_router(storage.router, prefix="/api/v1") app.include_router(storage.router, prefix="/api/v1")
app.include_router(vectors.router, prefix="/api/v1") app.include_router(vectors.router, prefix="/api/v1")
app.include_router(backup.router, prefix="/api/v1") app.include_router(backup.router, prefix="/api/v1")
app.include_router(plugins.router, prefix="/api/v1") app.include_router(plugins.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1") app.include_router(billing.router, prefix="/api/v1")
app.include_router(agents.router, prefix="/api/v1") app.include_router(agents.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"]) @app.get("/api/v1/health", tags=["health"])
async def health() -> dict: async def health() -> dict:

362
tests/test_device_ws.py Normal file
View File

@@ -0,0 +1,362 @@
"""Tests for Step 3.3: DeviceConnectionManager and device WS endpoint.
Coverage:
Unit tests — DeviceConnectionManager register/unregister/is_online/
get_ws/send_frame/pending-call round-trip/agent-data queue
Integration — /api/v1/ws/device endpoint via TestClient WebSocket:
auth rejection, happy-path connect, tool_result dispatch,
agent_data queue routing, agent_complete sentinel, disconnect
cleanup (AgentRunLog marked as error)
"""
from __future__ import annotations
import asyncio
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from app.core.device_manager import DeviceConnection, DeviceConnectionManager
from app.db import get_session
from app.main import app
from app.models import AgentRunLog
from tests.conftest import TEST_USER_IDS, auth_header, make_jwt
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_FREE_UID = TEST_USER_IDS["free"]
_PRO_UID = TEST_USER_IDS["pro"]
def _device_hello(device_id: str = "dev-001", agent_ids: list[str] | None = None) -> str:
return json.dumps(
{"type": "device_hello", "device_id": device_id, "agent_ids": agent_ids or []}
)
# ---------------------------------------------------------------------------
# DB override (shared across integration tests)
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _override_db(db_session):
"""Route all get_session calls to the test SQLite session."""
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
# ---------------------------------------------------------------------------
# DeviceConnectionManager unit tests
# ---------------------------------------------------------------------------
@pytest.fixture()
def manager() -> DeviceConnectionManager:
"""Fresh manager instance for each test."""
return DeviceConnectionManager()
@pytest.fixture()
def mock_ws() -> MagicMock:
ws = MagicMock()
ws.send_text = AsyncMock()
return ws
def test_manager_register_and_is_online(manager, mock_ws):
assert not manager.is_online("user1")
manager.register("user1", "dev-A", mock_ws)
assert manager.is_online("user1")
assert manager.is_online("user1", "dev-A")
assert not manager.is_online("user1", "dev-B")
def test_manager_get_ws_returns_none_when_offline(manager):
assert manager.get_ws("no-such-user") is None
def test_manager_unregister(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
assert manager.is_online("user1")
manager.unregister("user1")
assert not manager.is_online("user1")
assert manager.get_ws("user1") is None
def test_manager_unregister_unknown_is_noop(manager):
# Must not raise.
manager.unregister("ghost")
def test_manager_replace_connection_cancels_old_futures(manager):
ws_a = MagicMock()
ws_a.send_text = AsyncMock()
ws_b = MagicMock()
ws_b.send_text = AsyncMock()
# Create event loop context for Future.
loop = asyncio.new_event_loop()
try:
async def _run():
manager.register("user1", "dev-A", ws_a)
fut = manager.create_pending_call("user1", "call-1")
# Replace connection — old future should be cancelled.
manager.register("user1", "dev-B", ws_b)
assert fut.cancelled()
loop.run_until_complete(_run())
finally:
loop.close()
@pytest.mark.asyncio
async def test_manager_send_frame(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
await manager.send_frame("user1", {"type": "ping"})
mock_ws.send_text.assert_called_once_with(json.dumps({"type": "ping"}))
@pytest.mark.asyncio
async def test_manager_send_frame_raises_when_offline(manager):
with pytest.raises(RuntimeError, match="not connected"):
await manager.send_frame("ghost", {"type": "ping"})
@pytest.mark.asyncio
async def test_manager_pending_call_round_trip(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
fut = manager.create_pending_call("user1", "call-42")
result = {"type": "tool_result", "id": "call-42", "rows": [{"id": "row1"}]}
manager.resolve_pending_call("user1", "call-42", result)
assert fut.done()
assert await fut == result
@pytest.mark.asyncio
async def test_manager_resolve_unknown_call_is_noop(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
# Should not raise.
manager.resolve_pending_call("user1", "no-such-call", {})
@pytest.mark.asyncio
async def test_manager_unregister_cancels_pending_calls(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
fut = manager.create_pending_call("user1", "call-1")
manager.unregister("user1")
assert fut.cancelled()
@pytest.mark.asyncio
async def test_manager_agent_data_queue(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
q = manager.get_agent_data_queue("user1", "run-xyz")
# Put a frame and get it back.
frame = {"type": "agent_data", "run_id": "run-xyz", "files": []}
await q.put(frame)
assert await q.get() == frame
@pytest.mark.asyncio
async def test_manager_agent_data_queue_creates_once(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
q1 = manager.get_agent_data_queue("user1", "run-1")
q2 = manager.get_agent_data_queue("user1", "run-1")
assert q1 is q2
@pytest.mark.asyncio
async def test_manager_agent_data_queue_raises_when_offline(manager):
with pytest.raises(RuntimeError, match="not connected"):
manager.get_agent_data_queue("ghost", "run-1")
@pytest.mark.asyncio
async def test_manager_cleanup_agent_data_queue(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
manager.get_agent_data_queue("user1", "run-1")
manager.cleanup_agent_data_queue("user1", "run-1")
# After cleanup a new queue is created (not the same object).
q_new = manager.get_agent_data_queue("user1", "run-1")
assert q_new is not None
# ---------------------------------------------------------------------------
# Integration tests — /api/v1/ws/device endpoint
# ---------------------------------------------------------------------------
def test_ws_device_rejects_without_token(client):
with pytest.raises(Exception):
# TestClient will raise or close when the server rejects.
with client.websocket_connect("/api/v1/ws/device") as ws:
ws.receive_text()
def test_ws_device_rejects_invalid_token(client):
with pytest.raises(Exception):
with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws:
ws.receive_text()
def test_ws_device_happy_path(client):
"""Connect, send device_hello, receive ping, then close."""
token = make_jwt(tier="free")
# Patch the heartbeat sleep so the test doesn't block 30 s.
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.01):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(_device_hello("dev-001"))
# Next message from server should be a heartbeat ping (interval=0.01s).
msg = ws.receive_text()
data = json.loads(msg)
assert data["type"] == "ping"
# Close gracefully.
ws.close()
def test_ws_device_invalid_first_frame_closes(client):
"""Non-device_hello first frame should close the connection."""
token = make_jwt(tier="free")
with pytest.raises(Exception):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({"type": "chat_request", "message": "hi"}))
ws.receive_text() # server should close after bad frame
def test_ws_device_tool_result_dispatched(client):
"""tool_result frame is routed to the DeviceConnectionManager."""
token = make_jwt(tier="free")
user_id = TEST_USER_IDS["free"]
from app.core.device_manager import device_manager as dm
captured: list[dict] = []
original_resolve = dm.resolve_pending_call
def _spy(uid, call_id, result):
captured.append({"uid": uid, "call_id": call_id, "result": result})
original_resolve(uid, call_id, result)
with patch.object(dm, "resolve_pending_call", side_effect=_spy):
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(_device_hello("dev-001"))
# Send a tool_result frame.
ws.send_text(
json.dumps(
{
"type": "tool_result",
"id": "call-123",
"rows": [{"id": "task-1", "title": "Buy milk"}],
}
)
)
ws.close()
assert any(c["call_id"] == "call-123" for c in captured)
def test_ws_device_agent_data_enqueued(client):
"""agent_data frame is placed in the per-run queue by the message loop."""
from app.core.device_manager import device_manager as dm
token = make_jwt(tier="free")
user_id = TEST_USER_IDS["free"]
# Capture the queue object the message loop accesses.
captured_queue: list[asyncio.Queue] = []
original_get_queue = dm.get_agent_data_queue
def _spy_get_queue(uid, run_id):
q = original_get_queue(uid, run_id)
if not captured_queue:
captured_queue.append(q)
return q
with patch.object(dm, "get_agent_data_queue", side_effect=_spy_get_queue):
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(_device_hello("dev-001"))
ws.send_text(
json.dumps(
{
"type": "agent_data",
"run_id": "run-XYZ",
"files": [{"path": "/tmp/file.txt", "content": "hello"}],
}
)
)
ws.close()
# The queue should have received exactly one frame.
assert captured_queue, "queue was never accessed"
assert not captured_queue[0].empty()
def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session):
"""On disconnect, _mark_runs_disconnected is called with the correct user_id."""
from app.api.routes import device_ws as _dws
token = make_jwt(tier="free")
user_id = TEST_USER_IDS["free"]
cleanup_calls: list[str] = []
async def _fake_cleanup(uid: str) -> None:
cleanup_calls.append(uid)
with patch.object(_dws, "_mark_runs_disconnected", side_effect=_fake_cleanup):
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(_device_hello("dev-001"))
ws.close()
assert user_id in cleanup_calls
@pytest.mark.asyncio
async def test_mark_runs_disconnected_updates_db(db_session):
"""_mark_runs_disconnected marks in-progress runs as error in the DB."""
from sqlalchemy import select
from app.api.routes.device_ws import _mark_runs_disconnected
from tests.conftest import _TestSessionLocal
user_id = TEST_USER_IDS["free"]
run_log = AgentRunLog(
id=str(uuid.uuid4()),
agent_id=str(uuid.uuid4()),
agent_type="local",
user_id=user_id,
status="running",
started_at=datetime.now(timezone.utc),
)
db_session.add(run_log)
await db_session.commit()
# Route the function to the same test-DB session factory.
with patch("app.api.routes.device_ws.async_session", _TestSessionLocal):
await _mark_runs_disconnected(user_id)
# Verify through the same session factory.
async with _TestSessionLocal() as s:
result = await s.execute(
select(AgentRunLog).where(AgentRunLog.id == run_log.id)
)
updated = result.scalar_one_or_none()
assert updated is not None
assert updated.status == "error"
assert updated.errors and "device disconnected" in updated.errors