Files
api/services/ws-gateway/app/handler.py
Roberto Musso 90018af311 feat: add WS Gateway and Chat Service (Step 2)
WS Gateway:
- WebSocket lifecycle handler with RS256 JWT auth
- Redis bridge: device registry, frame publishing, tool_result routing
- Inbound routing: tool_result→LPUSH, home/floating→chat pub/sub
- Outbound: subscribes to ws:out:{user_id}, forwards to Electron
- Single-worker Dockerfile (long-lived WS connections)

Chat Service:
- Redis consumer: subscribes to chat:request:* pattern
- Redis-based ws_context: tool_call→publish, BRPOP tool_result (30s timeout)
- deep_agent: single-agent runner with home/floating/stream variants
- memory_middleware: core/associative/episodic/proactive memory with Fernet
- Domain agents: task (8 tools), note (5), project (6), timeline (4)
- LLM factory via LiteLLM (100+ providers)
- Output formatter (StreamFormatter)
- POST /chat REST fallback with Traefik header auth
- Multi-worker Dockerfile with 120s timeout for LLM calls
2026-03-22 01:20:11 +01:00

174 lines
6.4 KiB
Python

"""WebSocket handler — device connection lifecycle.
Accepts Electron WS connections, authenticates JWT, registers device in Redis,
and runs three concurrent loops:
1. Message loop: receive frames from Electron, route to Redis
2. Outbound loop: subscribe to Redis ws:out:{user_id}, forward to Electron
3. Heartbeat loop: ping every 30s
No business logic lives here — the handler is a JSON frame router.
"""
from __future__ import annotations
import asyncio
import json
import logging
from uuid import uuid4
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from shared.config import settings
from shared.schemas import WsFrameType
from app.redis_bridge import (
publish_batch_request,
publish_chat_request,
push_tool_result,
register_device,
set_gateway_id,
subscribe_outbound,
unregister_device,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Every route declared on this router is served under the /ws prefix.
router = APIRouter(prefix="/ws", tags=["ws-gateway"])
# Delay between two consecutive ping frames sent by _heartbeat_loop.
_HEARTBEAT_INTERVAL = 30 # seconds
# Set a unique gateway instance ID on module load
# NOTE(review): this is an import-time side effect — each process that imports
# this module generates a fresh random UUID. How redis_bridge uses the ID is
# not visible from this file.
set_gateway_id(str(uuid4()))
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
    """Persistent WebSocket endpoint for Electron device connections.

    Lifecycle:
      1. Validate the RS256 JWT passed in the ``?token=`` query parameter;
         the connection is rejected (code 1008) when the token is invalid or
         carries no ``sub`` claim.
      2. Accept the socket and wait up to 15 seconds for a ``device_hello``
         frame carrying ``device_id`` (and optionally ``agent_ids``).
      3. Register the device, announce ``device_online`` and subscribe to the
         user's outbound channel.
      4. Run the inbound, outbound and heartbeat loops until the first of
         them finishes, then cancel the others.
      5. Always release the pubsub subscription and unregister the device.

    Args:
        websocket: The incoming WebSocket connection.
    """
    # ── 1. Authenticate via ?token= query parameter ──────────────────
    token = websocket.query_params.get("token", "")
    try:
        payload = jwt.decode(
            token,
            settings.JWT_PUBLIC_KEY,
            algorithms=["RS256"],
        )
        user_id: str | None = payload.get("sub")
        if not user_id:
            raise JWTError("missing sub")
    except JWTError:
        await websocket.close(code=1008)
        return

    await websocket.accept()

    # ── 2. Await device_hello frame ──────────────────────────────────
    try:
        raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
    except WebSocketDisconnect:
        # The peer is already gone; there is nothing left to close.
        return
    except asyncio.TimeoutError:
        await websocket.close(code=1008)
        return

    try:
        hello = json.loads(raw)
        # Valid JSON that is not an object (e.g. a list or a string) has no
        # .get(); reject it here instead of crashing with AttributeError.
        if not isinstance(hello, dict):
            raise ValueError("device_hello must be a JSON object")
        if hello.get("type") != WsFrameType.device_hello:
            raise ValueError("expected device_hello as first frame")
        device_id: str = hello["device_id"]
        agent_ids: list[str] = hello.get("agent_ids", [])
    except (KeyError, ValueError, json.JSONDecodeError) as exc:
        logger.warning("handler: invalid device_hello user=%s: %s", user_id, exc)
        await websocket.close(code=1008)
        return

    # ── 3. Register device in Redis ──────────────────────────────────
    await register_device(user_id, device_id)
    pubsub = None
    try:
        logger.info(
            "handler: connected user=%s device=%s agents=%s",
            user_id,
            device_id,
            agent_ids,
        )
        # Notify downstream services that device is online (for agent trigger)
        await publish_batch_request(user_id, {
            "type": "device_online",
            "user_id": user_id,
            "device_id": device_id,
            "agent_ids": agent_ids,
        })

        # ── 4. Subscribe to outbound Redis channel ───────────────────
        pubsub = await subscribe_outbound(user_id)

        # ── 5. Run concurrent loops ──────────────────────────────────
        # The loops run as tasks and we stop at the FIRST one that finishes.
        # A plain gather() would keep waiting on the two infinite loops after
        # the inbound loop returns (Starlette's iter_text() ends the iteration
        # on disconnect instead of raising), and would leave sibling loops
        # running after one of them failed.
        tasks = [
            asyncio.create_task(_inbound_loop(websocket, user_id)),
            asyncio.create_task(_outbound_loop(websocket, pubsub)),
            asyncio.create_task(_heartbeat_loop(websocket)),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Cancelling an already finished task is a no-op.
            for task in tasks:
                task.cancel()
            outcomes = await asyncio.gather(*tasks, return_exceptions=True)

        for outcome in outcomes:
            # A disconnect or our own cancellation is the normal way out.
            if isinstance(outcome, Exception) and not isinstance(
                outcome, (WebSocketDisconnect, asyncio.CancelledError)
            ):
                logger.warning(
                    "handler: unhandled exception user=%s: %s",
                    user_id,
                    outcome,
                    exc_info=outcome,
                )
    finally:
        # Cleanup is nested so that a failure while releasing the pubsub can
        # never prevent the device from being unregistered.
        try:
            if pubsub is not None:
                try:
                    await pubsub.unsubscribe()
                finally:
                    await pubsub.aclose()
        finally:
            await unregister_device(user_id)
            logger.info("handler: disconnected user=%s device=%s", user_id, device_id)
# ── Inbound: Electron → Redis ────────────────────────────────────────
async def _inbound_loop(websocket: WebSocket, user_id: str) -> None:
    """Receive frames from Electron and route to the appropriate Redis channel.

    Each text frame is parsed as JSON and dispatched on its ``type`` field:
    ``tool_result`` frames go to ``push_tool_result`` (keyed by the frame's
    ``id``), home/floating requests go to ``publish_chat_request`` and journey
    frames go to ``publish_batch_request``. ``pong`` frames are ignored and
    any other type is logged at debug level.

    Frames that are not valid JSON, or that are valid JSON but not an object,
    are logged and skipped without closing the connection.

    Args:
        websocket: The accepted device connection to read from.
        user_id: Authenticated user id; written into every routed frame.
    """
    async for raw in websocket.iter_text():
        try:
            frame = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("handler: invalid JSON from user=%s", user_id)
            continue
        # json.loads also accepts lists, strings and numbers; those have no
        # "type" key and must not be allowed to crash the loop.
        if not isinstance(frame, dict):
            logger.warning("handler: non-object frame from user=%s", user_id)
            continue

        frame_type = frame.get("type")
        # Inject user_id so downstream services know who sent it. This also
        # overwrites any user_id value supplied by the client.
        frame["user_id"] = user_id

        if frame_type == WsFrameType.tool_result:
            call_id = frame.get("id")
            if call_id:
                await push_tool_result(call_id, frame)
            else:
                logger.warning("handler: tool_result missing id user=%s", user_id)
        elif frame_type in (WsFrameType.home_request, WsFrameType.floating_request):
            await publish_chat_request(user_id, frame)
        elif frame_type in (WsFrameType.journey_start, WsFrameType.journey_message):
            await publish_batch_request(user_id, frame)
        elif frame_type == "pong":
            pass  # heartbeat ack
        else:
            logger.debug("handler: unknown frame type %r user=%s", frame_type, user_id)
# ── Outbound: Redis → Electron ───────────────────────────────────────
async def _outbound_loop(websocket: WebSocket, pubsub) -> None:
    """Forward messages received on the given pubsub to the device socket.

    Polls ``pubsub`` forever. Entries whose type is ``"message"`` have their
    ``data`` sent to the WebSocket as a text frame; anything else (including
    an empty poll) is followed by a short pause before polling again.

    Args:
        websocket: The accepted device connection to write to.
        pubsub: Subscription object returned by ``subscribe_outbound``.
    """
    while True:
        incoming = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
        if incoming is None or incoming["type"] != "message":
            # Nothing to forward: pause briefly so the loop never spins.
            await asyncio.sleep(0.01)
            continue
        await websocket.send_text(incoming["data"])
# ── Heartbeat ────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None:
    """Emit a ``{"type": "ping"}`` text frame every ``_HEARTBEAT_INTERVAL`` seconds.

    Runs until a send fails or the surrounding task is cancelled. The first
    ping goes out one full interval after the loop starts.

    Args:
        websocket: The accepted device connection to write to.
    """
    # The payload never changes, so serialize it once up front.
    ping_frame = json.dumps({"type": "ping"})
    while True:
        await asyncio.sleep(_HEARTBEAT_INTERVAL)
        await websocket.send_text(ping_frame)