step 3.3 complete: device WS endpoint + DeviceConnectionManager
This commit is contained in:
226
app/api/routes/device_ws.py
Normal file
226
app/api/routes/device_ws.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""Device WebSocket endpoint.
|
||||
|
||||
Persistent connection from Electron devices to the backend.
|
||||
|
||||
WS /api/v1/ws/device?token=<jwt>
|
||||
|
||||
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
|
||||
available during the WebSocket handshake).
|
||||
|
||||
Protocol:
|
||||
1. Client connects → JWT validated → connection accepted.
|
||||
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
|
||||
3. Backend registers the connection in ``DeviceConnectionManager``.
|
||||
4. Session enters message dispatch loop + heartbeat.
|
||||
|
||||
Incoming frame dispatch:
|
||||
- ``tool_result`` → resolves a pending tool-call Future.
|
||||
- ``agent_data`` → enqueued in the per-run agent data queue.
|
||||
- ``agent_complete`` → sends None sentinel to close the queue stream.
|
||||
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||
- unknown types → logged, ignored.
|
||||
|
||||
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
|
||||
|
||||
On disconnect:
|
||||
- Unregisters from DeviceConnectionManager.
|
||||
- Marks all in-progress AgentRunLog rows for this user as ``error``
|
||||
with message "device disconnected".
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.core.device_manager import device_manager
|
||||
from app.db import async_session
|
||||
from app.models import AgentRunLog
|
||||
from app.schemas import WsFrameType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||
|
||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||
|
||||
|
||||
@router.websocket("/device")
|
||||
async def device_ws(websocket: WebSocket) -> None:
|
||||
"""Persistent WebSocket endpoint for Electron device connections.
|
||||
|
||||
Authentication is via ``?token=<jwt>`` query parameter.
|
||||
"""
|
||||
# ── 1. Authenticate before accepting ─────────────────────────────
|
||||
token = websocket.query_params.get("token", "")
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
if not user_id:
|
||||
raise JWTError("missing sub")
|
||||
except JWTError:
|
||||
await websocket.close(code=1008) # Policy Violation
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
# ── 2. Await device_hello frame ───────────────────────────────────
|
||||
try:
|
||||
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
|
||||
except (asyncio.TimeoutError, WebSocketDisconnect):
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
try:
|
||||
hello = json.loads(raw)
|
||||
if hello.get("type") != WsFrameType.device_hello:
|
||||
raise ValueError("expected device_hello as first frame")
|
||||
device_id: str = hello["device_id"]
|
||||
agent_ids: list[str] = hello.get("agent_ids", [])
|
||||
except (KeyError, ValueError, json.JSONDecodeError) as exc:
|
||||
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# ── 3. Register connection ────────────────────────────────────────
|
||||
device_manager.register(user_id, device_id, websocket)
|
||||
logger.info(
|
||||
"device_ws: connected user=%s device=%s agents=%s",
|
||||
user_id,
|
||||
device_id,
|
||||
agent_ids,
|
||||
)
|
||||
|
||||
# Step 3.4 will replace this stub with a real call to agent_runner.
|
||||
asyncio.create_task(_trigger_pending_runs_stub(user_id, device_id))
|
||||
|
||||
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||
try:
|
||||
await asyncio.gather(
|
||||
_message_loop(websocket, user_id),
|
||||
_heartbeat_loop(websocket),
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
|
||||
finally:
|
||||
device_manager.unregister(user_id)
|
||||
logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
|
||||
await _mark_runs_disconnected(user_id)
|
||||
|
||||
|
||||
# ── Message dispatch loop ─────────────────────────────────────────────
|
||||
|
||||
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
"""Receive frames from Electron and dispatch to the appropriate handler."""
|
||||
async for raw in websocket.iter_text():
|
||||
try:
|
||||
frame: dict = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("device_ws: invalid JSON from user=%s", user_id)
|
||||
continue
|
||||
|
||||
frame_type = frame.get("type")
|
||||
|
||||
if frame_type == WsFrameType.tool_result:
|
||||
call_id = frame.get("id")
|
||||
if call_id:
|
||||
device_manager.resolve_pending_call(user_id, call_id, frame)
|
||||
else:
|
||||
logger.warning(
|
||||
"device_ws: tool_result missing id from user=%s", user_id
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.agent_data:
|
||||
run_id = frame.get("run_id")
|
||||
if run_id:
|
||||
try:
|
||||
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||
await queue.put(frame)
|
||||
except RuntimeError:
|
||||
logger.warning(
|
||||
"device_ws: agent_data for unknown run user=%s run=%s",
|
||||
user_id,
|
||||
run_id,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"device_ws: agent_data missing run_id from user=%s", user_id
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.agent_complete:
|
||||
run_id = frame.get("run_id")
|
||||
if run_id:
|
||||
try:
|
||||
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||
# Sentinel: signals the agent data stream is finished.
|
||||
await queue.put(None)
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
logger.warning(
|
||||
"device_ws: agent_complete missing run_id from user=%s", user_id
|
||||
)
|
||||
|
||||
elif frame_type == "pong":
|
||||
# Heartbeat ack — nothing to do, connection is alive.
|
||||
pass
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"device_ws: unknown frame type %r from user=%s", frame_type, user_id
|
||||
)
|
||||
|
||||
|
||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||
|
||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||
"""Send a ping frame every 30 s to keep the connection alive."""
|
||||
while True:
|
||||
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
||||
await websocket.send_text(json.dumps({"type": "ping"}))
|
||||
|
||||
|
||||
# ── Disconnect cleanup ────────────────────────────────────────────────
|
||||
|
||||
async def _mark_runs_disconnected(user_id: str) -> None:
|
||||
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
|
||||
try:
|
||||
async with async_session() as db:
|
||||
await db.execute(
|
||||
update(AgentRunLog)
|
||||
.where(
|
||||
AgentRunLog.user_id == user_id,
|
||||
AgentRunLog.status == "running",
|
||||
)
|
||||
.values(
|
||||
status="error",
|
||||
errors=["device disconnected"],
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: failed to mark runs as disconnected for user=%s: %s",
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
|
||||
# ── Pending-run trigger stub (Step 3.4 will replace) ─────────────────
|
||||
|
||||
async def _trigger_pending_runs_stub(user_id: str, device_id: str) -> None:
|
||||
"""No-op stub. Step 3.4 wires this to agent_runner.trigger_pending_runs."""
|
||||
logger.debug(
|
||||
"device_ws: _trigger_pending_runs stub user=%s device=%s", user_id, device_id
|
||||
)
|
||||
183
app/core/device_manager.py
Normal file
183
app/core/device_manager.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Device connection manager.
|
||||
|
||||
Maintains in-memory state for all active Electron → backend WebSocket
|
||||
connections. One connection per user (latest replaces previous).
|
||||
|
||||
The manager participates in two interaction patterns:
|
||||
|
||||
1. **Tool-call round-trip** (bidirectional CRUD):
|
||||
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
||||
``tool_result`` frame.
|
||||
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||
receive the result dict from Electron.
|
||||
|
||||
2. **Agent-data streaming** (local directory agent runs):
|
||||
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
||||
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
||||
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
||||
a specific ``run_id`` so the agent runner can iterate frames.
|
||||
|
||||
The ``device_manager`` module-level singleton is imported by both the
|
||||
device WS route and the agent runner.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceConnection:
|
||||
"""State for a single connected Electron device."""
|
||||
|
||||
ws: WebSocket
|
||||
device_id: str
|
||||
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||
# Per-run queues for agent_data / agent_complete frames.
|
||||
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DeviceConnectionManager:
|
||||
"""Singleton registry of active Electron WebSocket connections.
|
||||
|
||||
Thread/task safety note: asyncio is single-threaded by design. All
|
||||
mutations happen inside await-points on the main event loop, so no
|
||||
locking is required for the in-memory dicts.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connections: dict[str, DeviceConnection] = {}
|
||||
|
||||
# ── Registration ──────────────────────────────────────────────────
|
||||
|
||||
def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
|
||||
"""Store the active connection for *user_id*, replacing any previous one."""
|
||||
if user_id in self._connections:
|
||||
old = self._connections[user_id]
|
||||
logger.info(
|
||||
"device_manager: replacing existing connection for user=%s device=%s",
|
||||
user_id,
|
||||
old.device_id,
|
||||
)
|
||||
# Cancel any futures that were waiting on the old connection.
|
||||
for fut in old.pending_calls.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
|
||||
logger.info(
|
||||
"device_manager: registered user=%s device=%s", user_id, device_id
|
||||
)
|
||||
|
||||
def unregister(self, user_id: str) -> None:
|
||||
"""Remove the connection for *user_id* and cancel any pending futures."""
|
||||
conn = self._connections.pop(user_id, None)
|
||||
if conn is None:
|
||||
return
|
||||
for fut in conn.pending_calls.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
logger.info("device_manager: unregistered user=%s", user_id)
|
||||
|
||||
# ── Presence queries ──────────────────────────────────────────────
|
||||
|
||||
def get_ws(self, user_id: str) -> WebSocket | None:
|
||||
"""Return the active WebSocket for *user_id*, or ``None`` if offline."""
|
||||
conn = self._connections.get(user_id)
|
||||
return conn.ws if conn else None
|
||||
|
||||
def is_online(self, user_id: str, device_id: str | None = None) -> bool:
|
||||
"""Return ``True`` if the user has an active connection.
|
||||
|
||||
If *device_id* is provided also checks that it matches the connected device.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
return False
|
||||
if device_id is not None:
|
||||
return conn.device_id == device_id
|
||||
return True
|
||||
|
||||
# ── Frame sending ─────────────────────────────────────────────────
|
||||
|
||||
async def send_frame(self, user_id: str, frame: dict) -> None:
|
||||
"""Send *frame* as a JSON text message to the device.
|
||||
|
||||
Raises ``RuntimeError`` if the user is not connected.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
raise RuntimeError(
|
||||
f"send_frame: user {user_id!r} is not connected"
|
||||
)
|
||||
await conn.ws.send_text(json.dumps(frame))
|
||||
|
||||
# ── Tool-call round-trip ──────────────────────────────────────────
|
||||
|
||||
def create_pending_call(
|
||||
self, user_id: str, call_id: str
|
||||
) -> asyncio.Future[dict]:
|
||||
"""Register a Future that will be resolved when the tool_result arrives.
|
||||
|
||||
Raises ``RuntimeError`` if the user is not connected.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
raise RuntimeError(
|
||||
f"create_pending_call: user {user_id!r} is not connected"
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
fut: asyncio.Future[dict] = loop.create_future()
|
||||
conn.pending_calls[call_id] = fut
|
||||
return fut
|
||||
|
||||
def resolve_pending_call(
|
||||
self, user_id: str, call_id: str, result: dict
|
||||
) -> None:
|
||||
"""Fulfil the Future registered under *call_id* with the Electron result.
|
||||
|
||||
No-ops if the call_id is unknown (already timed out or cancelled).
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
return
|
||||
fut = conn.pending_calls.pop(call_id, None)
|
||||
if fut is not None and not fut.done():
|
||||
fut.set_result(result)
|
||||
|
||||
# ── Agent-data queue ──────────────────────────────────────────────
|
||||
|
||||
def get_agent_data_queue(
|
||||
self, user_id: str, run_id: str
|
||||
) -> asyncio.Queue[dict | None]:
|
||||
"""Return (creating if absent) the queue for *run_id* agent frames.
|
||||
|
||||
The agent runner reads from this queue. The device WS handler writes
|
||||
to it. ``None`` is the sentinel that signals the stream is finished.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
raise RuntimeError(
|
||||
f"get_agent_data_queue: user {user_id!r} is not connected"
|
||||
)
|
||||
if run_id not in conn.agent_data_queues:
|
||||
conn.agent_data_queues[run_id] = asyncio.Queue()
|
||||
return conn.agent_data_queues[run_id]
|
||||
|
||||
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
|
||||
"""Remove the queue for *run_id* once a run has completed."""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn:
|
||||
conn.agent_data_queues.pop(run_id, None)
|
||||
|
||||
|
||||
# Module-level singleton — import this everywhere.
|
||||
device_manager = DeviceConnectionManager()
|
||||
21
app/main.py
21
app/main.py
@@ -43,17 +43,18 @@ def create_app() -> FastAPI:
|
||||
app.add_middleware(SanitizerMiddleware)
|
||||
app.add_middleware(TierRateLimitMiddleware)
|
||||
|
||||
from app.api.routes import agents, auth, backup, billing, chat, plans, plugins, storage, vectors
|
||||
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(chat.router, prefix="/api/v1")
|
||||
app.include_router(plans.router, prefix="/api/v1")
|
||||
app.include_router(storage.router, prefix="/api/v1")
|
||||
app.include_router(vectors.router, prefix="/api/v1")
|
||||
app.include_router(backup.router, prefix="/api/v1")
|
||||
app.include_router(plugins.router, prefix="/api/v1")
|
||||
app.include_router(billing.router, prefix="/api/v1")
|
||||
app.include_router(agents.router, prefix="/api/v1")
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(chat.router, prefix="/api/v1")
|
||||
app.include_router(plans.router, prefix="/api/v1")
|
||||
app.include_router(storage.router, prefix="/api/v1")
|
||||
app.include_router(vectors.router, prefix="/api/v1")
|
||||
app.include_router(backup.router, prefix="/api/v1")
|
||||
app.include_router(plugins.router, prefix="/api/v1")
|
||||
app.include_router(billing.router, prefix="/api/v1")
|
||||
app.include_router(agents.router, prefix="/api/v1")
|
||||
app.include_router(device_ws.router, prefix="/api/v1")
|
||||
|
||||
@app.get("/api/v1/health", tags=["health"])
|
||||
async def health() -> dict:
|
||||
|
||||
Reference in New Issue
Block a user