184 lines
7.2 KiB
Python
184 lines
7.2 KiB
Python
"""Device connection manager.
|
|
|
|
Maintains in-memory state for all active Electron → backend WebSocket
|
|
connections. One connection per user (latest replaces previous).
|
|
|
|
The manager participates in two interaction patterns:
|
|
|
|
1. **Tool-call round-trip** (bidirectional CRUD):
|
|
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
|
``tool_result`` frame.
|
|
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
|
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
|
receive the result dict from Electron.
|
|
|
|
2. **Agent-data streaming** (local directory agent runs):
|
|
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
|
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
|
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
|
a specific ``run_id`` so the agent runner can iterate frames.
|
|
|
|
The ``device_manager`` module-level singleton is imported by both the
|
|
device WS route and the agent runner.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
|
|
from fastapi import WebSocket
|
|
|
|
# Module logger; all messages from this file are prefixed "device_manager:".
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class DeviceConnection:
    """Per-connection bookkeeping for one Electron device WebSocket.

    Attributes:
        ws: Socket that frames for this device are sent over.
        device_id: Identifier of the connected device.
        pending_calls: Outstanding tool calls keyed by ``call_id``; each
            Future is meant to be resolved when the matching
            ``tool_result`` frame arrives.
        agent_data_queues: One queue per agent ``run_id`` carrying
            ``agent_data`` / ``agent_complete`` payloads; ``None`` is used
            as the end-of-stream sentinel.
    """

    ws: WebSocket
    device_id: str
    pending_calls: dict[str, asyncio.Future[dict]] = field(
        default_factory=dict,
    )
    agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(
        default_factory=dict,
    )
|
|
|
|
|
|
class DeviceConnectionManager:
    """Singleton registry of active Electron WebSocket connections.

    Concurrency note: the manager is meant to be used from a single asyncio
    event loop. Apart from ``send_frame``, no method here contains an
    ``await``, so every dict mutation runs to completion without yielding
    to another task; that is why no locking is used for the in-memory
    dicts. This does NOT make the manager safe to call from other threads.
    """

    def __init__(self) -> None:
        # user_id -> live connection; at most one connection per user.
        self._connections: dict[str, DeviceConnection] = {}

    # ── Registration ──────────────────────────────────────────────────

    def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
        """Store the active connection for *user_id*, replacing any previous one.

        Work still in flight on a replaced connection is torn down: its
        pending tool-call futures are cancelled and its agent-data queues
        receive the ``None`` end-of-stream sentinel (see ``_teardown``).
        """
        old = self._connections.get(user_id)
        if old is not None:
            logger.info(
                "device_manager: replacing existing connection for user=%s device=%s",
                user_id,
                old.device_id,
            )
            self._teardown(old)
        self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
        logger.info(
            "device_manager: registered user=%s device=%s", user_id, device_id
        )

    def unregister(self, user_id: str, ws: WebSocket | None = None) -> None:
        """Remove the connection for *user_id* and release any pending work.

        Args:
            user_id: User whose connection should be dropped.
            ws: Optional socket that is going away. When given, the entry is
                removed only if it still belongs to this exact socket. Pass it
                from the WS route's disconnect handler so that an old socket
                closing *after* the user has reconnected does not evict the
                newer connection. Omitting it keeps the previous behaviour
                (unconditional removal).
        """
        conn = self._connections.get(user_id)
        if conn is None:
            return
        if ws is not None and conn.ws is not ws:
            # Stale socket: a newer connection already replaced it, and the
            # replaced connection was torn down inside ``register``.
            return
        del self._connections[user_id]
        self._teardown(conn)
        logger.info("device_manager: unregistered user=%s", user_id)

    @staticmethod
    def _teardown(conn: DeviceConnection) -> None:
        """Release every task still waiting on *conn*.

        Cancels unresolved tool-call futures (their awaiters receive
        ``CancelledError``) and pushes the ``None`` sentinel into every
        agent-data queue so consumers blocked in ``queue.get()`` stop waiting
        instead of hanging: no further frames can arrive on a dropped
        connection. NOTE(review): a consumer cannot tell this sentinel apart
        from a normal ``agent_complete``; confirm the agent runner tolerates
        a truncated stream.
        """
        for fut in list(conn.pending_calls.values()):
            if not fut.done():
                fut.cancel()
        conn.pending_calls.clear()
        for queue in conn.agent_data_queues.values():
            # Queues are created unbounded in get_agent_data_queue, so
            # put_nowait cannot raise QueueFull here.
            queue.put_nowait(None)

    # ── Presence queries ──────────────────────────────────────────────

    def get_ws(self, user_id: str) -> WebSocket | None:
        """Return the active WebSocket for *user_id*, or ``None`` if offline."""
        conn = self._connections.get(user_id)
        return conn.ws if conn else None

    def is_online(self, user_id: str, device_id: str | None = None) -> bool:
        """Return ``True`` if the user has an active connection.

        If *device_id* is provided, also check that it matches the connected
        device.
        """
        conn = self._connections.get(user_id)
        if conn is None:
            return False
        if device_id is not None:
            return conn.device_id == device_id
        return True

    # ── Frame sending ─────────────────────────────────────────────────

    async def send_frame(self, user_id: str, frame: dict) -> None:
        """Send *frame* as a JSON text message to the device.

        Raises:
            RuntimeError: if the user is not connected.
        """
        conn = self._connections.get(user_id)
        if conn is None:
            raise RuntimeError(
                f"send_frame: user {user_id!r} is not connected"
            )
        await conn.ws.send_text(json.dumps(frame))

    # ── Tool-call round-trip ──────────────────────────────────────────

    def create_pending_call(
        self, user_id: str, call_id: str
    ) -> asyncio.Future[dict]:
        """Register a Future that will be resolved when the tool_result arrives.

        The entry removes itself from the registry as soon as the Future
        settles for any reason (result, cancellation, or the caller's
        ``asyncio.wait_for`` timing out), so abandoned calls do not
        accumulate. Registering a ``call_id`` that is still pending cancels
        the earlier Future rather than orphaning its awaiter.

        Must be called from code running on the event loop.

        Raises:
            RuntimeError: if the user is not connected, or if no event loop
                is running.
        """
        conn = self._connections.get(user_id)
        if conn is None:
            raise RuntimeError(
                f"create_pending_call: user {user_id!r} is not connected"
            )
        # get_running_loop() is the supported way to obtain the loop from
        # async code; get_event_loop() is deprecated for this purpose.
        fut: asyncio.Future[dict] = asyncio.get_running_loop().create_future()

        pending = conn.pending_calls
        previous = pending.get(call_id)
        if previous is not None and not previous.done():
            previous.cancel()
        pending[call_id] = fut

        def _forget(done: asyncio.Future[dict]) -> None:
            # Only drop the entry if it still refers to this Future; it may
            # already have been popped by resolve_pending_call or replaced
            # by a newer call with the same id.
            if pending.get(call_id) is done:
                del pending[call_id]

        fut.add_done_callback(_forget)
        return fut

    def resolve_pending_call(
        self, user_id: str, call_id: str, result: dict
    ) -> None:
        """Fulfil the Future registered under *call_id* with the Electron result.

        No-ops if the user or call_id is unknown (already timed out, cancelled
        or resolved).
        """
        conn = self._connections.get(user_id)
        if conn is None:
            return
        fut = conn.pending_calls.pop(call_id, None)
        if fut is not None and not fut.done():
            fut.set_result(result)

    # ── Agent-data queue ──────────────────────────────────────────────

    def get_agent_data_queue(
        self, user_id: str, run_id: str
    ) -> asyncio.Queue[dict | None]:
        """Return (creating if absent) the queue for *run_id* agent frames.

        The agent runner reads from this queue. The device WS handler writes
        to it. ``None`` is the sentinel that signals the stream is finished.

        Raises:
            RuntimeError: if the user is not connected.
        """
        conn = self._connections.get(user_id)
        if conn is None:
            raise RuntimeError(
                f"get_agent_data_queue: user {user_id!r} is not connected"
            )
        if run_id not in conn.agent_data_queues:
            conn.agent_data_queues[run_id] = asyncio.Queue()
        return conn.agent_data_queues[run_id]

    def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
        """Remove the queue for *run_id* once a run has completed."""
        conn = self._connections.get(user_id)
        if conn:
            conn.agent_data_queues.pop(run_id, None)
|
|
|
|
|
|
# Shared process-wide instance. Import this name rather than constructing
# a new manager, so every importer sees the same connection registry.
device_manager: DeviceConnectionManager = DeviceConnectionManager()
|