222 lines
8.0 KiB
Python
222 lines
8.0 KiB
Python
"""Device WebSocket endpoint.

Persistent connection from Electron devices to the backend.

    WS /api/v1/ws/device?token=<jwt>

Auth: JWT passed as ``?token=`` query parameter (a Bearer header is not
available during the WebSocket handshake).

Protocol:

1. Client connects → JWT validated → connection accepted.
2. Client sends a ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
3. Backend registers the connection in ``DeviceConnectionManager``.
4. Session enters the message dispatch loop + heartbeat.

Incoming frame dispatch:

- ``tool_result`` → resolves a pending tool-call Future.
- ``agent_data`` → enqueued in the per-run agent data queue.
- ``agent_complete`` → sends a None sentinel to close the queue stream.
- ``pong`` → heartbeat acknowledgement (accepted; last-seen is not
  currently tracked).
- unknown types → logged, ignored.

Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.

On disconnect:

- Unregisters from DeviceConnectionManager.
- Marks all in-progress AgentRunLog rows for this user as ``error``
  with message "device disconnected".
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
from jose import JWTError, jwt
|
|
from sqlalchemy import select, update
|
|
|
|
from app.config.settings import settings
|
|
from app.core.agent_runner import trigger_pending_runs
|
|
from app.core.device_manager import device_manager
|
|
from app.db import async_session
|
|
from app.models import AgentRunLog
|
|
from app.schemas import WsFrameType
|
|
|
|
logger = logging.getLogger(__name__)

# Every route in this module is registered under the ``/ws`` prefix.
router = APIRouter(tags=["device-ws"], prefix="/ws")

# Seconds to wait between server-initiated ``ping`` frames.
_HEARTBEAT_INTERVAL: int = 30

# Intended grace window (seconds) for a ``pong`` after a ping.
# NOTE(review): nothing in the visible code reads this constant, so pong
# deadlines are not enforced yet.
_PONG_TIMEOUT: int = 10
|
|
|
|
|
|
# Strong references to fire-and-forget tasks started by ``device_ws``. The
# event loop only keeps weak references to tasks, so an otherwise unreferenced
# task can be garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
    """Persistent WebSocket endpoint for Electron device connections.

    Authentication is via the ``?token=<jwt>`` query parameter; the token's
    ``sub`` claim is used as the user id. After the handshake is accepted, the
    first frame must be a ``device_hello`` JSON object and must arrive within
    15 s. Auth or hello failures close the socket with code 1008 (policy
    violation).

    Once registered, the session runs the message loop and heartbeat until
    either one stops. The connection is then unregistered and the user's
    ``running`` AgentRunLog rows are marked as errored.
    """
    # ── 1. Authenticate before accepting ─────────────────────────────
    token = websocket.query_params.get("token", "")
    try:
        payload = jwt.decode(
            token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
        user_id: str | None = payload.get("sub")
        if not user_id:
            raise JWTError("missing sub")
    except JWTError:
        await websocket.close(code=1008)  # Policy Violation
        return

    await websocket.accept()

    # ── 2. Await device_hello frame ───────────────────────────────────
    try:
        raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
    except WebSocketDisconnect:
        # The client already disconnected; there is no open connection to close.
        return
    except asyncio.TimeoutError:
        await websocket.close(code=1008)
        return

    try:
        hello = json.loads(raw)
        # Valid JSON that is not an object (list, number, string) has no
        # ``.get``. Reject it like any other malformed hello instead of
        # letting an AttributeError/TypeError escape the endpoint.
        if not isinstance(hello, dict):
            raise ValueError("device_hello must be a JSON object")
        if hello.get("type") != WsFrameType.device_hello:
            raise ValueError("expected device_hello as first frame")
        device_id: str = hello["device_id"]
        agent_ids: list[str] = hello.get("agent_ids", [])
    except (KeyError, ValueError, json.JSONDecodeError) as exc:
        logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
        await websocket.close(code=1008)
        return

    # ── 3. Register connection ────────────────────────────────────────
    device_manager.register(user_id, device_id, websocket)
    logger.info(
        "device_ws: connected user=%s device=%s agents=%s",
        user_id,
        device_id,
        agent_ids,
    )

    # Trigger any overdue agent runs now that the device is connected.
    # Hold a strong reference until the task completes (see _background_tasks).
    pending_runs = asyncio.create_task(
        trigger_pending_runs(user_id, device_id, device_manager)
    )
    _background_tasks.add(pending_runs)
    pending_runs.add_done_callback(_background_tasks.discard)

    # ── 4. Concurrent message loop + heartbeat ────────────────────────
    try:
        await _run_session(websocket, user_id)
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        logger.warning(
            "device_ws: unhandled exception user=%s: %s",
            user_id,
            exc,
            exc_info=True,
        )
    finally:
        # NOTE(review): unregister() and _mark_runs_disconnected() are keyed by
        # user_id only. If the same user reconnects before this cleanup runs,
        # they may act on the *new* connection. Confirm that
        # DeviceConnectionManager guards against this.
        device_manager.unregister(user_id)
        logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
        await _mark_runs_disconnected(user_id)


async def _run_session(websocket: WebSocket, user_id: str) -> None:
    """Run the message loop and heartbeat until the first of them finishes.

    ``asyncio.gather`` would keep waiting on the heartbeat after the message
    loop ended, delaying cleanup until the next heartbeat send fails. It also
    does not cancel the sibling when one worker raises. Here the surviving
    worker is cancelled and awaited, so nothing keeps using the socket
    afterwards. An exception raised by a finished worker is re-raised to the
    caller.
    """
    workers = [
        asyncio.create_task(_message_loop(websocket, user_id)),
        asyncio.create_task(_heartbeat_loop(websocket)),
    ]
    try:
        done, _ = await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in workers:
            task.cancel()  # no-op for tasks that already finished
        # return_exceptions=True: wait for every worker to settle without the
        # first error aborting the wait (this also marks exceptions retrieved).
        await asyncio.gather(*workers, return_exceptions=True)
    for task in workers:
        if task in done:
            task.result()  # re-raises the worker's exception, if any
|
|
|
|
|
|
# ── Message dispatch loop ─────────────────────────────────────────────
|
|
|
|
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
    """Receive frames from the device and dispatch each one by its ``type``.

    Runs until ``websocket.iter_text()`` is exhausted. Each text frame is
    parsed as JSON. Frames that are invalid JSON, or valid JSON but not an
    object, are logged and skipped so that one bad frame does not end the
    session.

    Dispatch:
      * ``tool_result``    -> ``device_manager.resolve_pending_call`` (needs ``id``)
      * ``agent_data``     -> frame put on the run's agent-data queue (needs ``run_id``)
      * ``agent_complete`` -> ``None`` sentinel put on that queue (needs ``run_id``)
      * ``pong``           -> ignored
      * anything else      -> logged at debug level and ignored
    """
    async for raw in websocket.iter_text():
        try:
            frame = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("device_ws: invalid JSON from user=%s", user_id)
            continue

        if not isinstance(frame, dict):
            # e.g. ``[]`` or ``42``: ``frame.get`` below would raise
            # AttributeError and tear down the whole connection.
            logger.warning("device_ws: non-object frame from user=%s", user_id)
            continue

        frame_type = frame.get("type")

        if frame_type == WsFrameType.tool_result:
            call_id = frame.get("id")
            if call_id:
                device_manager.resolve_pending_call(user_id, call_id, frame)
            else:
                logger.warning(
                    "device_ws: tool_result missing id from user=%s", user_id
                )

        elif frame_type == WsFrameType.agent_data:
            run_id = frame.get("run_id")
            if run_id:
                try:
                    queue = device_manager.get_agent_data_queue(user_id, run_id)
                    await queue.put(frame)
                except RuntimeError:
                    # Treated as "run unknown to the device manager".
                    logger.warning(
                        "device_ws: agent_data for unknown run user=%s run=%s",
                        user_id,
                        run_id,
                    )
            else:
                logger.warning(
                    "device_ws: agent_data missing run_id from user=%s", user_id
                )

        elif frame_type == WsFrameType.agent_complete:
            run_id = frame.get("run_id")
            if run_id:
                try:
                    queue = device_manager.get_agent_data_queue(user_id, run_id)
                    # Sentinel: signals the agent data stream is finished.
                    await queue.put(None)
                except RuntimeError:
                    # Unknown run (the same condition is logged for agent_data
                    # above): there is no stream to close, so drop it silently.
                    pass
            else:
                logger.warning(
                    "device_ws: agent_complete missing run_id from user=%s", user_id
                )

        elif frame_type == "pong":
            # Heartbeat ack — nothing to do, connection is alive.
            # NOTE(review): pongs are not timestamped, so _PONG_TIMEOUT is never
            # enforced and a peer that stops answering is not detected here.
            pass

        else:
            logger.debug(
                "device_ws: unknown frame type %r from user=%s", frame_type, user_id
            )
|
|
|
|
|
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
|
|
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
    """Keep the connection warm by pushing a ``ping`` frame periodically.

    Sleeps ``_HEARTBEAT_INTERVAL`` seconds before every send and never returns
    on its own. It stops only when cancelled or when ``send_text`` raises, in
    which case the error propagates to the caller.
    """
    # The payload never changes, so serialize it once up front.
    ping_payload = json.dumps({"type": "ping"})
    while True:
        await asyncio.sleep(_HEARTBEAT_INTERVAL)
        await websocket.send_text(ping_payload)
|
|
|
|
|
|
# ── Disconnect cleanup ────────────────────────────────────────────────
|
|
|
|
async def _mark_runs_disconnected(user_id: str) -> None:
    """Mark all in-progress AgentRunLog rows as 'error' for this user.

    Every row with ``user_id == user_id`` and ``status == "running"`` gets
    ``status="error"`` and has its ``errors`` value replaced with
    ``["device disconnected"]``. The update is committed in its own session.

    Best-effort: any failure is logged, now with its traceback, and then
    swallowed. This lets the disconnect path that awaits this coroutine finish
    normally.

    NOTE(review): the filter is user-wide rather than per device or per
    connection. If a user can have several live devices, or reconnects
    quickly, runs owned by another connection would also be marked. Confirm
    this is intended.
    """
    try:
        async with async_session() as db:
            await db.execute(
                update(AgentRunLog)
                .where(
                    AgentRunLog.user_id == user_id,
                    AgentRunLog.status == "running",
                )
                .values(
                    status="error",
                    errors=["device disconnected"],
                )
            )
            await db.commit()
    except Exception as exc:
        # exc_info=True keeps the traceback; "%s" of the exception alone hides
        # where in the session/DB layer the failure happened.
        logger.error(
            "device_ws: failed to mark runs as disconnected for user=%s: %s",
            user_id,
            exc,
            exc_info=True,
        )
|
|
|
|
|
|
|