step 3.3 complete: device WS endpoint + DeviceConnectionManager
This commit is contained in:
@@ -348,7 +348,7 @@ Cloud Agent:
|
|||||||
- **Outcome:** Full CRUD for agent configs with tier-gated creation limits.
|
- **Outcome:** Full CRUD for agent configs with tier-gated creation limits.
|
||||||
|
|
||||||
### Step 3.3 — Device WS endpoint
|
### Step 3.3 — Device WS endpoint
|
||||||
- [ ] Create `app/api/routes/device_ws.py`:
|
- [x] Create `app/api/routes/device_ws.py`:
|
||||||
- `WebSocket /api/v1/ws/device?token=<jwt>` — persistent connection from Electron
|
- `WebSocket /api/v1/ws/device?token=<jwt>` — persistent connection from Electron
|
||||||
- On connect:
|
- On connect:
|
||||||
- Authenticate JWT
|
- Authenticate JWT
|
||||||
@@ -364,7 +364,7 @@ Cloud Agent:
|
|||||||
- Remove from `DeviceConnectionManager`
|
- Remove from `DeviceConnectionManager`
|
||||||
- Mark any in-progress agent runs as `error` with "device disconnected"
|
- Mark any in-progress agent runs as `error` with "device disconnected"
|
||||||
- Heartbeat: send `ping` every 30s, disconnect if no `pong` within 10s
|
- Heartbeat: send `ping` every 30s, disconnect if no `pong` within 10s
|
||||||
- [ ] Create `app/core/device_manager.py`:
|
- [x] Create `app/core/device_manager.py`:
|
||||||
- `DeviceConnectionManager` (singleton):
|
- `DeviceConnectionManager` (singleton):
|
||||||
- `register(user_id, device_id, ws)` — stores active connection
|
- `register(user_id, device_id, ws)` — stores active connection
|
||||||
- `unregister(user_id)` — removes connection
|
- `unregister(user_id)` — removes connection
|
||||||
|
|||||||
226
app/api/routes/device_ws.py
Normal file
226
app/api/routes/device_ws.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
"""Device WebSocket endpoint.
|
||||||
|
|
||||||
|
Persistent connection from Electron devices to the backend.
|
||||||
|
|
||||||
|
WS /api/v1/ws/device?token=<jwt>
|
||||||
|
|
||||||
|
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
|
||||||
|
available during the WebSocket handshake).
|
||||||
|
|
||||||
|
Protocol:
|
||||||
|
1. Client connects → JWT validated → connection accepted.
|
||||||
|
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
|
||||||
|
3. Backend registers the connection in ``DeviceConnectionManager``.
|
||||||
|
4. Session enters message dispatch loop + heartbeat.
|
||||||
|
|
||||||
|
Incoming frame dispatch:
|
||||||
|
- ``tool_result`` → resolves a pending tool-call Future.
|
||||||
|
- ``agent_data`` → enqueued in the per-run agent data queue.
|
||||||
|
- ``agent_complete`` → sends None sentinel to close the queue stream.
|
||||||
|
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||||
|
- unknown types → logged, ignored.
|
||||||
|
|
||||||
|
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
|
||||||
|
|
||||||
|
On disconnect:
|
||||||
|
- Unregisters from DeviceConnectionManager.
|
||||||
|
- Marks all in-progress AgentRunLog rows for this user as ``error``
|
||||||
|
with message "device disconnected".
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.core.device_manager import device_manager
|
||||||
|
from app.db import async_session
|
||||||
|
from app.models import AgentRunLog
|
||||||
|
from app.schemas import WsFrameType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||||
|
|
||||||
|
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||||
|
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/device")
|
||||||
|
async def device_ws(websocket: WebSocket) -> None:
|
||||||
|
"""Persistent WebSocket endpoint for Electron device connections.
|
||||||
|
|
||||||
|
Authentication is via ``?token=<jwt>`` query parameter.
|
||||||
|
"""
|
||||||
|
# ── 1. Authenticate before accepting ─────────────────────────────
|
||||||
|
token = websocket.query_params.get("token", "")
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
if not user_id:
|
||||||
|
raise JWTError("missing sub")
|
||||||
|
except JWTError:
|
||||||
|
await websocket.close(code=1008) # Policy Violation
|
||||||
|
return
|
||||||
|
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
# ── 2. Await device_hello frame ───────────────────────────────────
|
||||||
|
try:
|
||||||
|
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
|
||||||
|
except (asyncio.TimeoutError, WebSocketDisconnect):
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
hello = json.loads(raw)
|
||||||
|
if hello.get("type") != WsFrameType.device_hello:
|
||||||
|
raise ValueError("expected device_hello as first frame")
|
||||||
|
device_id: str = hello["device_id"]
|
||||||
|
agent_ids: list[str] = hello.get("agent_ids", [])
|
||||||
|
except (KeyError, ValueError, json.JSONDecodeError) as exc:
|
||||||
|
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 3. Register connection ────────────────────────────────────────
|
||||||
|
device_manager.register(user_id, device_id, websocket)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: connected user=%s device=%s agents=%s",
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
agent_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3.4 will replace this stub with a real call to agent_runner.
|
||||||
|
asyncio.create_task(_trigger_pending_runs_stub(user_id, device_id))
|
||||||
|
|
||||||
|
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||||
|
try:
|
||||||
|
await asyncio.gather(
|
||||||
|
_message_loop(websocket, user_id),
|
||||||
|
_heartbeat_loop(websocket),
|
||||||
|
)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
|
||||||
|
finally:
|
||||||
|
device_manager.unregister(user_id)
|
||||||
|
logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
|
||||||
|
await _mark_runs_disconnected(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Message dispatch loop ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||||
|
"""Receive frames from Electron and dispatch to the appropriate handler."""
|
||||||
|
async for raw in websocket.iter_text():
|
||||||
|
try:
|
||||||
|
frame: dict = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("device_ws: invalid JSON from user=%s", user_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
frame_type = frame.get("type")
|
||||||
|
|
||||||
|
if frame_type == WsFrameType.tool_result:
|
||||||
|
call_id = frame.get("id")
|
||||||
|
if call_id:
|
||||||
|
device_manager.resolve_pending_call(user_id, call_id, frame)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: tool_result missing id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.agent_data:
|
||||||
|
run_id = frame.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
try:
|
||||||
|
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||||
|
await queue.put(frame)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_data for unknown run user=%s run=%s",
|
||||||
|
user_id,
|
||||||
|
run_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_data missing run_id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.agent_complete:
|
||||||
|
run_id = frame.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
try:
|
||||||
|
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||||
|
# Sentinel: signals the agent data stream is finished.
|
||||||
|
await queue.put(None)
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_complete missing run_id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == "pong":
|
||||||
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"device_ws: unknown frame type %r from user=%s", frame_type, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||||
|
"""Send a ping frame every 30 s to keep the connection alive."""
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
||||||
|
await websocket.send_text(json.dumps({"type": "ping"}))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Disconnect cleanup ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _mark_runs_disconnected(user_id: str) -> None:
|
||||||
|
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
await db.execute(
|
||||||
|
update(AgentRunLog)
|
||||||
|
.where(
|
||||||
|
AgentRunLog.user_id == user_id,
|
||||||
|
AgentRunLog.status == "running",
|
||||||
|
)
|
||||||
|
.values(
|
||||||
|
status="error",
|
||||||
|
errors=["device disconnected"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: failed to mark runs as disconnected for user=%s: %s",
|
||||||
|
user_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pending-run trigger stub (Step 3.4 will replace) ─────────────────
|
||||||
|
|
||||||
|
async def _trigger_pending_runs_stub(user_id: str, device_id: str) -> None:
|
||||||
|
"""No-op stub. Step 3.4 wires this to agent_runner.trigger_pending_runs."""
|
||||||
|
logger.debug(
|
||||||
|
"device_ws: _trigger_pending_runs stub user=%s device=%s", user_id, device_id
|
||||||
|
)
|
||||||
183
app/core/device_manager.py
Normal file
183
app/core/device_manager.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""Device connection manager.
|
||||||
|
|
||||||
|
Maintains in-memory state for all active Electron → backend WebSocket
|
||||||
|
connections. One connection per user (latest replaces previous).
|
||||||
|
|
||||||
|
The manager participates in two interaction patterns:
|
||||||
|
|
||||||
|
1. **Tool-call round-trip** (bidirectional CRUD):
|
||||||
|
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
||||||
|
``tool_result`` frame.
|
||||||
|
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||||
|
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||||
|
receive the result dict from Electron.
|
||||||
|
|
||||||
|
2. **Agent-data streaming** (local directory agent runs):
|
||||||
|
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
||||||
|
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
||||||
|
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
||||||
|
a specific ``run_id`` so the agent runner can iterate frames.
|
||||||
|
|
||||||
|
The ``device_manager`` module-level singleton is imported by both the
|
||||||
|
device WS route and the agent runner.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeviceConnection:
|
||||||
|
"""State for a single connected Electron device."""
|
||||||
|
|
||||||
|
ws: WebSocket
|
||||||
|
device_id: str
|
||||||
|
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||||
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
|
# Per-run queues for agent_data / agent_complete frames.
|
||||||
|
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceConnectionManager:
|
||||||
|
"""Singleton registry of active Electron WebSocket connections.
|
||||||
|
|
||||||
|
Thread/task safety note: asyncio is single-threaded by design. All
|
||||||
|
mutations happen inside await-points on the main event loop, so no
|
||||||
|
locking is required for the in-memory dicts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._connections: dict[str, DeviceConnection] = {}
|
||||||
|
|
||||||
|
# ── Registration ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
|
||||||
|
"""Store the active connection for *user_id*, replacing any previous one."""
|
||||||
|
if user_id in self._connections:
|
||||||
|
old = self._connections[user_id]
|
||||||
|
logger.info(
|
||||||
|
"device_manager: replacing existing connection for user=%s device=%s",
|
||||||
|
user_id,
|
||||||
|
old.device_id,
|
||||||
|
)
|
||||||
|
# Cancel any futures that were waiting on the old connection.
|
||||||
|
for fut in old.pending_calls.values():
|
||||||
|
if not fut.done():
|
||||||
|
fut.cancel()
|
||||||
|
self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
|
||||||
|
logger.info(
|
||||||
|
"device_manager: registered user=%s device=%s", user_id, device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
def unregister(self, user_id: str) -> None:
|
||||||
|
"""Remove the connection for *user_id* and cancel any pending futures."""
|
||||||
|
conn = self._connections.pop(user_id, None)
|
||||||
|
if conn is None:
|
||||||
|
return
|
||||||
|
for fut in conn.pending_calls.values():
|
||||||
|
if not fut.done():
|
||||||
|
fut.cancel()
|
||||||
|
logger.info("device_manager: unregistered user=%s", user_id)
|
||||||
|
|
||||||
|
# ── Presence queries ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_ws(self, user_id: str) -> WebSocket | None:
|
||||||
|
"""Return the active WebSocket for *user_id*, or ``None`` if offline."""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
return conn.ws if conn else None
|
||||||
|
|
||||||
|
def is_online(self, user_id: str, device_id: str | None = None) -> bool:
|
||||||
|
"""Return ``True`` if the user has an active connection.
|
||||||
|
|
||||||
|
If *device_id* is provided also checks that it matches the connected device.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
return False
|
||||||
|
if device_id is not None:
|
||||||
|
return conn.device_id == device_id
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ── Frame sending ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def send_frame(self, user_id: str, frame: dict) -> None:
|
||||||
|
"""Send *frame* as a JSON text message to the device.
|
||||||
|
|
||||||
|
Raises ``RuntimeError`` if the user is not connected.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"send_frame: user {user_id!r} is not connected"
|
||||||
|
)
|
||||||
|
await conn.ws.send_text(json.dumps(frame))
|
||||||
|
|
||||||
|
# ── Tool-call round-trip ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def create_pending_call(
|
||||||
|
self, user_id: str, call_id: str
|
||||||
|
) -> asyncio.Future[dict]:
|
||||||
|
"""Register a Future that will be resolved when the tool_result arrives.
|
||||||
|
|
||||||
|
Raises ``RuntimeError`` if the user is not connected.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"create_pending_call: user {user_id!r} is not connected"
|
||||||
|
)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
fut: asyncio.Future[dict] = loop.create_future()
|
||||||
|
conn.pending_calls[call_id] = fut
|
||||||
|
return fut
|
||||||
|
|
||||||
|
def resolve_pending_call(
|
||||||
|
self, user_id: str, call_id: str, result: dict
|
||||||
|
) -> None:
|
||||||
|
"""Fulfil the Future registered under *call_id* with the Electron result.
|
||||||
|
|
||||||
|
No-ops if the call_id is unknown (already timed out or cancelled).
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
return
|
||||||
|
fut = conn.pending_calls.pop(call_id, None)
|
||||||
|
if fut is not None and not fut.done():
|
||||||
|
fut.set_result(result)
|
||||||
|
|
||||||
|
# ── Agent-data queue ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_agent_data_queue(
|
||||||
|
self, user_id: str, run_id: str
|
||||||
|
) -> asyncio.Queue[dict | None]:
|
||||||
|
"""Return (creating if absent) the queue for *run_id* agent frames.
|
||||||
|
|
||||||
|
The agent runner reads from this queue. The device WS handler writes
|
||||||
|
to it. ``None`` is the sentinel that signals the stream is finished.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"get_agent_data_queue: user {user_id!r} is not connected"
|
||||||
|
)
|
||||||
|
if run_id not in conn.agent_data_queues:
|
||||||
|
conn.agent_data_queues[run_id] = asyncio.Queue()
|
||||||
|
return conn.agent_data_queues[run_id]
|
||||||
|
|
||||||
|
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
|
||||||
|
"""Remove the queue for *run_id* once a run has completed."""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn:
|
||||||
|
conn.agent_data_queues.pop(run_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton — import this everywhere.
|
||||||
|
device_manager = DeviceConnectionManager()
|
||||||
@@ -43,7 +43,7 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agents, auth, backup, billing, chat, plans, plugins, storage, vectors
|
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
@@ -54,6 +54,7 @@ def create_app() -> FastAPI:
|
|||||||
app.include_router(plugins.router, prefix="/api/v1")
|
app.include_router(plugins.router, prefix="/api/v1")
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
async def health() -> dict:
|
async def health() -> dict:
|
||||||
|
|||||||
362
tests/test_device_ws.py
Normal file
362
tests/test_device_ws.py
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
"""Tests for Step 3.3: DeviceConnectionManager and device WS endpoint.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
Unit tests — DeviceConnectionManager register/unregister/is_online/
|
||||||
|
get_ws/send_frame/pending-call round-trip/agent-data queue
|
||||||
|
Integration — /api/v1/ws/device endpoint via TestClient WebSocket:
|
||||||
|
auth rejection, happy-path connect, tool_result dispatch,
|
||||||
|
agent_data queue routing, agent_complete sentinel, disconnect
|
||||||
|
cleanup (AgentRunLog marked as error)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.core.device_manager import DeviceConnection, DeviceConnectionManager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import AgentRunLog
|
||||||
|
from tests.conftest import TEST_USER_IDS, auth_header, make_jwt
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_FREE_UID = TEST_USER_IDS["free"]
|
||||||
|
_PRO_UID = TEST_USER_IDS["pro"]
|
||||||
|
|
||||||
|
|
||||||
|
def _device_hello(device_id: str = "dev-001", agent_ids: list[str] | None = None) -> str:
|
||||||
|
return json.dumps(
|
||||||
|
{"type": "device_hello", "device_id": device_id, "agent_ids": agent_ids or []}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DB override (shared across integration tests)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
"""Route all get_session calls to the test SQLite session."""
|
||||||
|
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DeviceConnectionManager unit tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def manager() -> DeviceConnectionManager:
|
||||||
|
"""Fresh manager instance for each test."""
|
||||||
|
return DeviceConnectionManager()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_ws() -> MagicMock:
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text = AsyncMock()
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_register_and_is_online(manager, mock_ws):
|
||||||
|
assert not manager.is_online("user1")
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
assert manager.is_online("user1")
|
||||||
|
assert manager.is_online("user1", "dev-A")
|
||||||
|
assert not manager.is_online("user1", "dev-B")
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_get_ws_returns_none_when_offline(manager):
|
||||||
|
assert manager.get_ws("no-such-user") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_unregister(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
assert manager.is_online("user1")
|
||||||
|
manager.unregister("user1")
|
||||||
|
assert not manager.is_online("user1")
|
||||||
|
assert manager.get_ws("user1") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_unregister_unknown_is_noop(manager):
|
||||||
|
# Must not raise.
|
||||||
|
manager.unregister("ghost")
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_replace_connection_cancels_old_futures(manager):
|
||||||
|
ws_a = MagicMock()
|
||||||
|
ws_a.send_text = AsyncMock()
|
||||||
|
ws_b = MagicMock()
|
||||||
|
ws_b.send_text = AsyncMock()
|
||||||
|
|
||||||
|
# Create event loop context for Future.
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
async def _run():
|
||||||
|
manager.register("user1", "dev-A", ws_a)
|
||||||
|
fut = manager.create_pending_call("user1", "call-1")
|
||||||
|
# Replace connection — old future should be cancelled.
|
||||||
|
manager.register("user1", "dev-B", ws_b)
|
||||||
|
assert fut.cancelled()
|
||||||
|
|
||||||
|
loop.run_until_complete(_run())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_send_frame(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
await manager.send_frame("user1", {"type": "ping"})
|
||||||
|
mock_ws.send_text.assert_called_once_with(json.dumps({"type": "ping"}))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_send_frame_raises_when_offline(manager):
|
||||||
|
with pytest.raises(RuntimeError, match="not connected"):
|
||||||
|
await manager.send_frame("ghost", {"type": "ping"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_pending_call_round_trip(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
fut = manager.create_pending_call("user1", "call-42")
|
||||||
|
result = {"type": "tool_result", "id": "call-42", "rows": [{"id": "row1"}]}
|
||||||
|
manager.resolve_pending_call("user1", "call-42", result)
|
||||||
|
assert fut.done()
|
||||||
|
assert await fut == result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_resolve_unknown_call_is_noop(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
# Should not raise.
|
||||||
|
manager.resolve_pending_call("user1", "no-such-call", {})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_unregister_cancels_pending_calls(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
fut = manager.create_pending_call("user1", "call-1")
|
||||||
|
manager.unregister("user1")
|
||||||
|
assert fut.cancelled()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_agent_data_queue(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
q = manager.get_agent_data_queue("user1", "run-xyz")
|
||||||
|
# Put a frame and get it back.
|
||||||
|
frame = {"type": "agent_data", "run_id": "run-xyz", "files": []}
|
||||||
|
await q.put(frame)
|
||||||
|
assert await q.get() == frame
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_agent_data_queue_creates_once(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
q1 = manager.get_agent_data_queue("user1", "run-1")
|
||||||
|
q2 = manager.get_agent_data_queue("user1", "run-1")
|
||||||
|
assert q1 is q2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_agent_data_queue_raises_when_offline(manager):
|
||||||
|
with pytest.raises(RuntimeError, match="not connected"):
|
||||||
|
manager.get_agent_data_queue("ghost", "run-1")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_cleanup_agent_data_queue(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
manager.get_agent_data_queue("user1", "run-1")
|
||||||
|
manager.cleanup_agent_data_queue("user1", "run-1")
|
||||||
|
# After cleanup a new queue is created (not the same object).
|
||||||
|
q_new = manager.get_agent_data_queue("user1", "run-1")
|
||||||
|
assert q_new is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration tests — /api/v1/ws/device endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_ws_device_rejects_without_token(client):
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
# TestClient will raise or close when the server rejects.
|
||||||
|
with client.websocket_connect("/api/v1/ws/device") as ws:
|
||||||
|
ws.receive_text()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_device_rejects_invalid_token(client):
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws:
|
||||||
|
ws.receive_text()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_device_happy_path(client):
|
||||||
|
"""Connect, send device_hello, receive ping, then close."""
|
||||||
|
token = make_jwt(tier="free")
|
||||||
|
|
||||||
|
# Patch the heartbeat sleep so the test doesn't block 30 s.
|
||||||
|
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.01):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(_device_hello("dev-001"))
|
||||||
|
# Next message from server should be a heartbeat ping (interval=0.01s).
|
||||||
|
msg = ws.receive_text()
|
||||||
|
data = json.loads(msg)
|
||||||
|
assert data["type"] == "ping"
|
||||||
|
# Close gracefully.
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_device_invalid_first_frame_closes(client):
|
||||||
|
"""Non-device_hello first frame should close the connection."""
|
||||||
|
token = make_jwt(tier="free")
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(json.dumps({"type": "chat_request", "message": "hi"}))
|
||||||
|
ws.receive_text() # server should close after bad frame
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_device_tool_result_dispatched(client):
|
||||||
|
"""tool_result frame is routed to the DeviceConnectionManager."""
|
||||||
|
token = make_jwt(tier="free")
|
||||||
|
user_id = TEST_USER_IDS["free"]
|
||||||
|
|
||||||
|
from app.core.device_manager import device_manager as dm
|
||||||
|
|
||||||
|
captured: list[dict] = []
|
||||||
|
|
||||||
|
original_resolve = dm.resolve_pending_call
|
||||||
|
|
||||||
|
def _spy(uid, call_id, result):
|
||||||
|
captured.append({"uid": uid, "call_id": call_id, "result": result})
|
||||||
|
original_resolve(uid, call_id, result)
|
||||||
|
|
||||||
|
with patch.object(dm, "resolve_pending_call", side_effect=_spy):
|
||||||
|
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(_device_hello("dev-001"))
|
||||||
|
# Send a tool_result frame.
|
||||||
|
ws.send_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"id": "call-123",
|
||||||
|
"rows": [{"id": "task-1", "title": "Buy milk"}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
assert any(c["call_id"] == "call-123" for c in captured)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_device_agent_data_enqueued(client):
|
||||||
|
"""agent_data frame is placed in the per-run queue by the message loop."""
|
||||||
|
from app.core.device_manager import device_manager as dm
|
||||||
|
|
||||||
|
token = make_jwt(tier="free")
|
||||||
|
user_id = TEST_USER_IDS["free"]
|
||||||
|
|
||||||
|
# Capture the queue object the message loop accesses.
|
||||||
|
captured_queue: list[asyncio.Queue] = []
|
||||||
|
original_get_queue = dm.get_agent_data_queue
|
||||||
|
|
||||||
|
def _spy_get_queue(uid, run_id):
|
||||||
|
q = original_get_queue(uid, run_id)
|
||||||
|
if not captured_queue:
|
||||||
|
captured_queue.append(q)
|
||||||
|
return q
|
||||||
|
|
||||||
|
with patch.object(dm, "get_agent_data_queue", side_effect=_spy_get_queue):
|
||||||
|
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(_device_hello("dev-001"))
|
||||||
|
ws.send_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "agent_data",
|
||||||
|
"run_id": "run-XYZ",
|
||||||
|
"files": [{"path": "/tmp/file.txt", "content": "hello"}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
# The queue should have received exactly one frame.
|
||||||
|
assert captured_queue, "queue was never accessed"
|
||||||
|
assert not captured_queue[0].empty()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session):
|
||||||
|
"""On disconnect, _mark_runs_disconnected is called with the correct user_id."""
|
||||||
|
from app.api.routes import device_ws as _dws
|
||||||
|
|
||||||
|
token = make_jwt(tier="free")
|
||||||
|
user_id = TEST_USER_IDS["free"]
|
||||||
|
|
||||||
|
cleanup_calls: list[str] = []
|
||||||
|
|
||||||
|
async def _fake_cleanup(uid: str) -> None:
|
||||||
|
cleanup_calls.append(uid)
|
||||||
|
|
||||||
|
with patch.object(_dws, "_mark_runs_disconnected", side_effect=_fake_cleanup):
|
||||||
|
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(_device_hello("dev-001"))
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
assert user_id in cleanup_calls
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mark_runs_disconnected_updates_db(db_session):
|
||||||
|
"""_mark_runs_disconnected marks in-progress runs as error in the DB."""
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.api.routes.device_ws import _mark_runs_disconnected
|
||||||
|
from tests.conftest import _TestSessionLocal
|
||||||
|
|
||||||
|
user_id = TEST_USER_IDS["free"]
|
||||||
|
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent_id=str(uuid.uuid4()),
|
||||||
|
agent_type="local",
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(run_log)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Route the function to the same test-DB session factory.
|
||||||
|
with patch("app.api.routes.device_ws.async_session", _TestSessionLocal):
|
||||||
|
await _mark_runs_disconnected(user_id)
|
||||||
|
|
||||||
|
# Verify through the same session factory.
|
||||||
|
async with _TestSessionLocal() as s:
|
||||||
|
result = await s.execute(
|
||||||
|
select(AgentRunLog).where(AgentRunLog.id == run_log.id)
|
||||||
|
)
|
||||||
|
updated = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
assert updated is not None
|
||||||
|
assert updated.status == "error"
|
||||||
|
assert updated.errors and "device disconnected" in updated.errors
|
||||||
Reference in New Issue
Block a user