"""Device WebSocket endpoint. Persistent connection from Electron devices to the backend. WS /api/v1/ws/device?token= Auth: JWT passed as ``?token=`` query parameter (Bearer header is not available during the WebSocket handshake). Protocol: 1. Client connects → JWT validated → connection accepted. 2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``. 3. Backend registers the connection in ``DeviceConnectionManager``. 4. Session enters message dispatch loop + heartbeat. Incoming frame dispatch: - ``tool_result`` → resolves a pending tool-call Future. - ``agent_data`` → enqueued in the per-run agent data queue. - ``agent_complete`` → sends None sentinel to close the queue stream. - ``pong`` → heartbeat acknowledgement (updates last-seen). - unknown types → logged, ignored. Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s. On disconnect: - Unregisters from DeviceConnectionManager. - Marks all in-progress AgentRunLog rows for this user as ``error`` with message "device disconnected". """ from __future__ import annotations import asyncio import json import logging from fastapi import APIRouter, WebSocket, WebSocketDisconnect from jose import JWTError, jwt from sqlalchemy import select, update from app.config.settings import settings from app.core.agent_runner import trigger_pending_runs from app.core.device_manager import device_manager from app.db import async_session from app.models import AgentRunLog from app.schemas import WsFrameType logger = logging.getLogger(__name__) router = APIRouter(prefix="/ws", tags=["device-ws"]) _HEARTBEAT_INTERVAL = 30 # seconds _PONG_TIMEOUT = 10 # seconds — grace window after a ping @router.websocket("/device") async def device_ws(websocket: WebSocket) -> None: """Persistent WebSocket endpoint for Electron device connections. Authentication is via ``?token=`` query parameter. """ # ── 1. Authenticate before accepting ───────────────────────────── token = websocket.query_params.get("token", "") try: payload = jwt.decode( token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] ) user_id: str | None = payload.get("sub") if not user_id: raise JWTError("missing sub") except JWTError: await websocket.close(code=1008) # Policy Violation return await websocket.accept() # ── 2. Await device_hello frame ─────────────────────────────────── try: raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0) except (asyncio.TimeoutError, WebSocketDisconnect): await websocket.close(code=1008) return try: hello = json.loads(raw) if hello.get("type") != WsFrameType.device_hello: raise ValueError("expected device_hello as first frame") device_id: str = hello["device_id"] agent_ids: list[str] = hello.get("agent_ids", []) except (KeyError, ValueError, json.JSONDecodeError) as exc: logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc) await websocket.close(code=1008) return # ── 3. Register connection ──────────────────────────────────────── device_manager.register(user_id, device_id, websocket) logger.info( "device_ws: connected user=%s device=%s agents=%s", user_id, device_id, agent_ids, ) # Trigger any overdue agent runs now that the device is connected. asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager)) # ── 4. Concurrent message loop + heartbeat ──────────────────────── try: await asyncio.gather( _message_loop(websocket, user_id), _heartbeat_loop(websocket), ) except WebSocketDisconnect: pass except Exception as exc: logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc) finally: device_manager.unregister(user_id) logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id) await _mark_runs_disconnected(user_id) # ── Message dispatch loop ───────────────────────────────────────────── async def _message_loop(websocket: WebSocket, user_id: str) -> None: """Receive frames from Electron and dispatch to the appropriate handler.""" async for raw in websocket.iter_text(): try: frame: dict = json.loads(raw) except json.JSONDecodeError: logger.warning("device_ws: invalid JSON from user=%s", user_id) continue frame_type = frame.get("type") if frame_type == WsFrameType.tool_result: call_id = frame.get("id") if call_id: device_manager.resolve_pending_call(user_id, call_id, frame) else: logger.warning( "device_ws: tool_result missing id from user=%s", user_id ) elif frame_type == WsFrameType.agent_data: run_id = frame.get("run_id") if run_id: try: queue = device_manager.get_agent_data_queue(user_id, run_id) await queue.put(frame) except RuntimeError: logger.warning( "device_ws: agent_data for unknown run user=%s run=%s", user_id, run_id, ) else: logger.warning( "device_ws: agent_data missing run_id from user=%s", user_id ) elif frame_type == WsFrameType.agent_complete: run_id = frame.get("run_id") if run_id: try: queue = device_manager.get_agent_data_queue(user_id, run_id) # Sentinel: signals the agent data stream is finished. await queue.put(None) except RuntimeError: pass else: logger.warning( "device_ws: agent_complete missing run_id from user=%s", user_id ) elif frame_type == "pong": # Heartbeat ack — nothing to do, connection is alive. pass else: logger.debug( "device_ws: unknown frame type %r from user=%s", frame_type, user_id ) # ── Heartbeat ───────────────────────────────────────────────────────── async def _heartbeat_loop(websocket: WebSocket) -> None: """Send a ping frame every 30 s to keep the connection alive.""" while True: await asyncio.sleep(_HEARTBEAT_INTERVAL) await websocket.send_text(json.dumps({"type": "ping"})) # ── Disconnect cleanup ──────────────────────────────────────────────── async def _mark_runs_disconnected(user_id: str) -> None: """Mark all in-progress AgentRunLog rows as 'error' for this user.""" try: async with async_session() as db: await db.execute( update(AgentRunLog) .where( AgentRunLog.user_id == user_id, AgentRunLog.status == "running", ) .values( status="error", errors=["device disconnected"], ) ) await db.commit() except Exception as exc: logger.error( "device_ws: failed to mark runs as disconnected for user=%s: %s", user_id, exc, )