WS Gateway:
- WebSocket lifecycle handler with RS256 JWT auth
- Redis bridge: device registry, frame publishing, tool_result routing
- Inbound routing: tool_result→LPUSH, home/floating→chat pub/sub, journey_start/journey_message→batch requests
- Outbound: subscribes to ws:out:{user_id}, forwards to Electron
- Single-worker Dockerfile (long-lived WS connections)
Chat Service:
- Redis consumer: subscribes to chat:request:* pattern
- Redis-based ws_context: tool_call→publish, BRPOP tool_result (30s timeout)
- deep_agent: single-agent runner with home/floating/stream variants
- memory_middleware: core/associative/episodic/proactive memory with Fernet
- Domain agents: task (8 tools), note (5), project (6), timeline (4)
- LLM factory via LiteLLM (100+ providers)
- Output formatter (StreamFormatter)
- POST /chat REST fallback with Traefik header auth
- Multi-worker Dockerfile with 120s timeout for LLM calls
174 lines
6.4 KiB
Python
"""WebSocket handler — device connection lifecycle.
|
|
|
|
Accepts Electron WS connections, authenticates JWT, registers device in Redis,
|
|
and runs two concurrent loops:
|
|
1. Message loop: receive frames from Electron, route to Redis
|
|
2. Outbound loop: subscribe to Redis ws:out:{user_id}, forward to Electron
|
|
3. Heartbeat loop: ping every 30s
|
|
|
|
No business logic lives here — the handler is a JSON frame router.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from uuid import uuid4
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
from jose import JWTError, jwt
|
|
|
|
from shared.config import settings
|
|
from shared.schemas import WsFrameType
|
|
|
|
from app.redis_bridge import (
|
|
publish_batch_request,
|
|
publish_chat_request,
|
|
push_tool_result,
|
|
register_device,
|
|
set_gateway_id,
|
|
subscribe_outbound,
|
|
unregister_device,
|
|
)
|
|
|
|
logger = logging.getLogger(name=__name__)

router = APIRouter(tags=["ws-gateway"], prefix="/ws")

# How long _heartbeat_loop sleeps between ping frames, in seconds.
_HEARTBEAT_INTERVAL = 30

# Import-time side effect: every process that loads this module generates a
# fresh random UUID and hands it to app.redis_bridge.set_gateway_id as this
# gateway instance's identifier.
set_gateway_id(f"{uuid4()}")
|
|
|
|
|
|
_HELLO_TIMEOUT = 15.0  # seconds the client has to send its device_hello frame


@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
    """Persistent WebSocket endpoint for Electron device connections.

    Lifecycle:
      1. Authenticate the ``?token=`` query parameter as an RS256 JWT; if it is
         invalid or has no ``sub`` claim, close with code 1008 without accepting.
      2. Accept, then require a ``device_hello`` frame within ``_HELLO_TIMEOUT``
         seconds (otherwise close with code 1008).
      3. Register the device in Redis and publish a ``device_online`` batch request.
      4. Subscribe to the user's outbound Redis channel.
      5. Run the inbound, outbound and heartbeat loops until the first one ends,
         cancel the others, and release the Redis subscription and registration.
    """
    user_id = _authenticate(websocket)
    if user_id is None:
        await websocket.close(code=1008)
        return

    await websocket.accept()

    hello = await _receive_device_hello(websocket, user_id)
    if hello is None:
        return
    device_id, agent_ids = hello

    await register_device(user_id, device_id)
    logger.info("handler: connected user=%s device=%s agents=%s", user_id, device_id, agent_ids)

    pubsub = None
    # Everything after registration sits inside try/finally so the device is
    # always unregistered, even when publishing or subscribing fails.
    try:
        # Notify downstream services that the device is online (for agent trigger).
        await publish_batch_request(user_id, {
            "type": "device_online",
            "user_id": user_id,
            "device_id": device_id,
            "agent_ids": agent_ids,
        })
        pubsub = await subscribe_outbound(user_id)
        await _run_session_loops(websocket, user_id, pubsub)
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        logger.exception("handler: unhandled exception user=%s: %s", user_id, exc)
    finally:
        await _release_session(user_id, pubsub)
        logger.info("handler: disconnected user=%s device=%s", user_id, device_id)


def _authenticate(websocket: WebSocket) -> str | None:
    """Return the ``sub`` claim of the RS256 JWT in ``?token=``, or None if rejected."""
    token = websocket.query_params.get("token", "")
    try:
        payload = jwt.decode(token, settings.JWT_PUBLIC_KEY, algorithms=["RS256"])
    except JWTError:
        return None
    # A token without a (non-empty) subject cannot be tied to a user.
    return payload.get("sub") or None


def _parse_device_hello(raw: str) -> tuple[str, list[str]]:
    """Validate the first client frame and return ``(device_id, agent_ids)``.

    Raises:
        ValueError: if ``raw`` is not a JSON object whose ``type`` is
            ``device_hello`` with a non-empty string ``device_id`` and an optional
            list ``agent_ids`` (missing or null is treated as ``[]``).
            ``json.JSONDecodeError`` is a ``ValueError`` subclass and also covers
            malformed JSON.
    """
    hello = json.loads(raw)
    # Guard before .get(): a JSON array/number/string has no .get and would crash.
    if not isinstance(hello, dict) or hello.get("type") != WsFrameType.device_hello:
        raise ValueError("expected device_hello as first frame")
    device_id = hello.get("device_id")
    if not isinstance(device_id, str) or not device_id:
        raise ValueError("device_id must be a non-empty string")
    agent_ids = hello.get("agent_ids")
    if agent_ids is None:
        agent_ids = []
    if not isinstance(agent_ids, list):
        raise ValueError("agent_ids must be a list")
    return device_id, agent_ids


async def _receive_device_hello(websocket: WebSocket, user_id: str) -> tuple[str, list[str]] | None:
    """Wait for and validate the ``device_hello`` frame.

    Returns ``(device_id, agent_ids)``, or None once the connection has been
    rejected with close code 1008 or the client has already disconnected.
    """
    try:
        raw = await asyncio.wait_for(websocket.receive_text(), timeout=_HELLO_TIMEOUT)
    except WebSocketDisconnect:
        # The client is already gone; sending a close frame now could fail.
        return None
    except (asyncio.TimeoutError, KeyError):
        # KeyError: Starlette's receive_text() indexes message["text"], which is
        # absent when the client sends a binary frame.
        await websocket.close(code=1008)
        return None
    try:
        return _parse_device_hello(raw)
    except ValueError as exc:
        logger.warning("handler: invalid device_hello user=%s: %s", user_id, exc)
        await websocket.close(code=1008)
        return None


async def _run_session_loops(websocket: WebSocket, user_id: str, pubsub) -> None:
    """Run the inbound, outbound and heartbeat loops until the first one finishes.

    The remaining loops are then cancelled and awaited, so none of them keeps
    using ``websocket`` or ``pubsub`` after the caller tears them down. If a
    finished loop raised, that exception is re-raised (first in
    inbound/outbound/heartbeat order).
    """
    tasks = [
        asyncio.create_task(_inbound_loop(websocket, user_id)),
        asyncio.create_task(_outbound_loop(websocket, pubsub)),
        asyncio.create_task(_heartbeat_loop(websocket)),
    ]
    try:
        # gather() would wait for *all* loops; outbound/heartbeat never return on
        # their own, so a clean client disconnect would otherwise go unnoticed.
        done, _pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in tasks:
            task.cancel()
        # Drain the cancellations; return_exceptions also marks errors as retrieved.
        await asyncio.gather(*tasks, return_exceptions=True)
    for task in tasks:
        if task in done and not task.cancelled() and task.exception() is not None:
            raise task.exception()


async def _release_session(user_id: str, pubsub) -> None:
    """Tear down the outbound subscription (best effort) and unregister the device."""
    if pubsub is not None:
        # Run each step independently so a failed unsubscribe still closes the pubsub.
        for step in (pubsub.unsubscribe, pubsub.aclose):
            try:
                await step()
            except Exception as exc:
                logger.warning("handler: pubsub cleanup failed user=%s: %s", user_id, exc)
    # NOTE(review): unregister_device is keyed only by user_id; if a user can
    # reconnect (or hold several devices) before this runs, it may clear the newer
    # registration — confirm against app.redis_bridge.
    await unregister_device(user_id)
|
|
|
|
|
|
# ── Inbound: Electron → Redis ────────────────────────────────────────
|
|
|
|
async def _inbound_loop(websocket: WebSocket, user_id: str) -> None:
    """Receive JSON frames from Electron and route each to Redis by its ``type``.

    Routing:
      - ``tool_result``           -> push_tool_result(frame["id"], frame)
      - ``home_request`` /
        ``floating_request``      -> publish_chat_request(user_id, frame)
      - ``journey_start`` /
        ``journey_message``       -> publish_batch_request(user_id, frame)
      - ``pong``                  -> ignored (heartbeat ack)
      - anything else             -> logged at debug level and dropped

    Malformed frames (invalid JSON, or JSON that is not an object) are logged and
    skipped instead of tearing down the connection. Iteration ends when the client
    disconnects (Starlette's ``iter_text`` stops on ``WebSocketDisconnect``).
    """
    async for raw in websocket.iter_text():
        try:
            frame = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("handler: invalid JSON from user=%s", user_id)
            continue

        # Valid JSON can still be a list/number/string; .get() and item assignment
        # below would raise and kill the whole connection.
        if not isinstance(frame, dict):
            logger.warning("handler: non-object frame from user=%s", user_id)
            continue

        frame_type = frame.get("type")

        # Inject (overwriting any client-supplied value) the authenticated
        # user_id so downstream services know who sent it.
        frame["user_id"] = user_id

        if frame_type == WsFrameType.tool_result:
            call_id = frame.get("id")
            if call_id:
                # NOTE(review): nothing here checks that call_id belongs to this
                # user — confirm push_tool_result / its consumer enforces ownership.
                await push_tool_result(call_id, frame)
            else:
                logger.warning("handler: tool_result missing id user=%s", user_id)

        elif frame_type in (WsFrameType.home_request, WsFrameType.floating_request):
            await publish_chat_request(user_id, frame)

        elif frame_type in (WsFrameType.journey_start, WsFrameType.journey_message):
            await publish_batch_request(user_id, frame)

        elif frame_type == "pong":
            pass  # heartbeat ack

        else:
            logger.debug("handler: unknown frame type %r user=%s", frame_type, user_id)
|
|
|
|
|
|
# ── Outbound: Redis → Electron ───────────────────────────────────────
|
|
|
|
async def _outbound_loop(websocket: WebSocket, pubsub) -> None:
    """Forward messages from the Redis ``ws:out:{user_id}`` subscription to Electron.

    Loops forever; it stops only when cancelled or when ``get_message`` /
    ``send_text`` raises.

    Args:
        websocket: the accepted device connection.
        pubsub: the subscription returned by ``subscribe_outbound`` (presumably a
            redis.asyncio PubSub — only ``get_message(...)`` is used here).
    """
    while True:
        message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
        if message is None or message["type"] != "message":
            # get_message may return None without waiting; the short sleep keeps
            # that case from turning into a busy-wait.
            await asyncio.sleep(0.01)
            continue
        payload = message["data"]
        # A Redis client without decode_responses=True delivers bytes, but
        # send_text requires str.
        if isinstance(payload, (bytes, bytearray)):
            payload = payload.decode("utf-8")
        await websocket.send_text(payload)
|
|
|
|
|
|
# ── Heartbeat ────────────────────────────────────────────────────────
|
|
|
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
    """Keep the connection alive with a JSON ping frame every _HEARTBEAT_INTERVAL seconds.

    Never returns on its own: it ends only when cancelled or when send_text raises.
    The first ping goes out after one full interval, not immediately.
    """
    ping_frame = json.dumps({"type": "ping"})
    while True:
        await asyncio.sleep(_HEARTBEAT_INTERVAL)
        await websocket.send_text(ping_frame)
|