"""Device connection manager.

Maintains in-memory state for all active Electron → backend WebSocket
connections. One connection per user (latest replaces previous).

The manager participates in two interaction patterns:

1. **Tool-call round-trip** (bidirectional CRUD):

   - Backend sends ``tool_call`` frame → Electron executes CRUD → returns
     ``tool_result`` frame.
   - ``create_pending_call`` registers a Future keyed by ``call_id``.
   - ``resolve_pending_call`` fulfils the Future; callers awaiting it
     receive the result dict from Electron.

2. **Agent-data streaming** (local directory agent runs):

   - Backend sends ``agent_run`` frame → Electron reads files and sends
     back a stream of ``agent_data`` frames followed by ``agent_complete``.
   - ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for a
     specific ``run_id`` so the agent runner can iterate frames.

When a connection is removed (``unregister``) or replaced by a newer one
(``register``), its outstanding tool-call Futures are cancelled and each of
its agent-data queues receives the ``None`` end-of-stream sentinel, so no
task waiting on that connection is left blocked.

The ``device_manager`` module-level singleton is imported by both the
device WS route and the agent runner.
"""

from __future__ import annotations

import asyncio
import json
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # WebSocket is only referenced in annotations, which the ``__future__``
    # import above keeps as unevaluated strings, so it is not needed at
    # runtime in this module.
    from fastapi import WebSocket

logger = logging.getLogger(__name__)


@dataclass
class DeviceConnection:
    """State for a single connected Electron device."""

    ws: WebSocket
    device_id: str
    # Futures indexed by tool_call id — resolved when tool_result arrives.
    pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
    # Per-run queues for agent_data / agent_complete frames.
    agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(
        default_factory=dict
    )


def _release(conn: DeviceConnection) -> None:
    """Wake every task still waiting on *conn*: cancel Futures, end queues."""
    for fut in conn.pending_calls.values():
        if not fut.done():
            fut.cancel()
    for queue in conn.agent_data_queues.values():
        # ``None`` is the end-of-stream sentinel readers already handle. The
        # queues are created unbounded, so put_nowait cannot raise QueueFull.
        queue.put_nowait(None)


class DeviceConnectionManager:
    """Singleton registry of active Electron WebSocket connections.

    Thread/task safety note: asyncio is single-threaded by design. Every
    method reads and mutates the in-memory dicts synchronously (there is no
    ``await`` between a lookup and the matching update), so no locking is
    required as long as the manager is only used from the event-loop thread.
    """

    def __init__(self) -> None:
        self._connections: dict[str, DeviceConnection] = {}

    # ── Registration ──────────────────────────────────────────────────

    def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
        """Store the active connection for *user_id*, replacing any previous one.

        A replaced connection is released: its pending tool-call Futures are
        cancelled and its agent-data queues receive the ``None`` sentinel.
        The replaced WebSocket itself is not closed here.
        """
        old = self._connections.get(user_id)
        if old is not None:
            logger.info(
                "device_manager: replacing existing connection for user=%s device=%s",
                user_id,
                old.device_id,
            )
            _release(old)
        self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
        logger.info(
            "device_manager: registered user=%s device=%s", user_id, device_id
        )

    def unregister(self, user_id: str, ws: WebSocket | None = None) -> None:
        """Remove the connection for *user_id* and release its waiters.

        Args:
            user_id: User whose connection should be dropped.
            ws: Optional socket owned by the caller. When given, the
                connection is only removed if it still uses *this* socket.
                A handler whose socket was already superseded by a newer
                ``register`` call must not tear down the newer connection.
                When omitted, the current connection is removed unconditionally.
        """
        conn = self._connections.get(user_id)
        if conn is None:
            return
        if ws is not None and conn.ws is not ws:
            logger.debug(
                "device_manager: ignoring stale unregister for user=%s", user_id
            )
            return
        del self._connections[user_id]
        _release(conn)
        logger.info("device_manager: unregistered user=%s", user_id)

    # ── Presence queries ──────────────────────────────────────────────

    def get_ws(self, user_id: str) -> WebSocket | None:
        """Return the active WebSocket for *user_id*, or ``None`` if offline."""
        conn = self._connections.get(user_id)
        return conn.ws if conn else None

    def is_online(self, user_id: str, device_id: str | None = None) -> bool:
        """Return ``True`` if the user has an active connection.

        If *device_id* is provided, also check that it matches the connected
        device.
        """
        conn = self._connections.get(user_id)
        if conn is None:
            return False
        if device_id is not None:
            return conn.device_id == device_id
        return True

    # ── Frame sending ─────────────────────────────────────────────────

    async def send_frame(self, user_id: str, frame: dict) -> None:
        """Send *frame* as a JSON text message to the device.

        Raises:
            RuntimeError: if the user is not connected.
        """
        conn = self._connections.get(user_id)
        if conn is None:
            raise RuntimeError(f"send_frame: user {user_id!r} is not connected")
        await conn.ws.send_text(json.dumps(frame))

    # ── Tool-call round-trip ──────────────────────────────────────────

    def create_pending_call(self, user_id: str, call_id: str) -> asyncio.Future[dict]:
        """Register a Future that will be resolved when the tool_result arrives.

        The entry is removed from the registry automatically once the Future
        settles by any means (result, cancellation, or a caller-side timeout
        that cancels it), so abandoned calls do not accumulate. Re-using a
        *call_id* that is still pending cancels the earlier Future.

        Raises:
            RuntimeError: if the user is not connected, or if called outside
                a running event loop.
        """
        conn = self._connections.get(user_id)
        if conn is None:
            raise RuntimeError(
                f"create_pending_call: user {user_id!r} is not connected"
            )
        # Acquire the loop before mutating anything so a failure leaves the
        # registry untouched.
        loop = asyncio.get_running_loop()
        pending = conn.pending_calls

        previous = pending.get(call_id)
        if previous is not None and not previous.done():
            # Otherwise the earlier waiter would be orphaned and never woken.
            previous.cancel()

        fut: asyncio.Future[dict] = loop.create_future()

        def _forget(done: asyncio.Future[dict]) -> None:
            # Only drop the entry if it still refers to this Future; the
            # call_id may have been re-registered in the meantime.
            if pending.get(call_id) is done:
                del pending[call_id]

        fut.add_done_callback(_forget)
        pending[call_id] = fut
        return fut

    def resolve_pending_call(self, user_id: str, call_id: str, result: dict) -> None:
        """Fulfil the Future registered under *call_id* with the Electron result.

        No-ops if the user is offline or the call_id is unknown (already
        resolved, timed out, or cancelled).
        """
        conn = self._connections.get(user_id)
        if conn is None:
            return
        fut = conn.pending_calls.pop(call_id, None)
        if fut is not None and not fut.done():
            fut.set_result(result)

    # ── Agent-data queue ──────────────────────────────────────────────

    def get_agent_data_queue(self, user_id: str, run_id: str) -> asyncio.Queue[dict | None]:
        """Return (creating if absent) the queue for *run_id* agent frames.

        The agent runner reads from this queue; the device WS handler writes
        to it. ``None`` is the sentinel that signals the stream is finished.
        It is also delivered when the connection is unregistered or replaced.

        Raises:
            RuntimeError: if the user is not connected.
        """
        conn = self._connections.get(user_id)
        if conn is None:
            raise RuntimeError(
                f"get_agent_data_queue: user {user_id!r} is not connected"
            )
        if run_id not in conn.agent_data_queues:
            conn.agent_data_queues[run_id] = asyncio.Queue()
        return conn.agent_data_queues[run_id]

    def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
        """Remove the queue for *run_id* once a run has completed."""
        conn = self._connections.get(user_id)
        if conn:
            conn.agent_data_queues.pop(run_id, None)


# Module-level singleton — import this everywhere.
device_manager = DeviceConnectionManager()