feat: add WS Gateway and Chat Service (Step 2)
WS Gateway:
- WebSocket lifecycle handler with RS256 JWT auth
- Redis bridge: device registry, frame publishing, tool_result routing
- Inbound routing: tool_result→LPUSH, home/floating→chat pub/sub
- Outbound: subscribes to ws:out:{user_id}, forwards to Electron
- Single-worker Dockerfile (long-lived WS connections)
Chat Service:
- Redis consumer: subscribes to chat:request:* pattern
- Redis-based ws_context: tool_call→publish, BRPOP tool_result (30s timeout)
- deep_agent: single-agent runner with home/floating/stream variants
- memory_middleware: core/associative/episodic/proactive memory with Fernet
- Domain agents: task (8 tools), note (5), project (6), timeline (4)
- LLM factory via LiteLLM (100+ providers)
- Output formatter (StreamFormatter)
- POST /chat REST fallback with Traefik header auth
- Multi-worker Dockerfile with 120s timeout for LLM calls
This commit is contained in:
173
services/ws-gateway/app/handler.py
Normal file
173
services/ws-gateway/app/handler.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""WebSocket handler — device connection lifecycle.
|
||||
|
||||
Accepts Electron WS connections, authenticates JWT, registers device in Redis,
|
||||
and runs two concurrent loops:
|
||||
1. Message loop: receive frames from Electron, route to Redis
|
||||
2. Outbound loop: subscribe to Redis ws:out:{user_id}, forward to Electron
|
||||
3. Heartbeat loop: ping every 30s
|
||||
|
||||
No business logic lives here — the handler is a JSON frame router.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from shared.config import settings
|
||||
from shared.schemas import WsFrameType
|
||||
|
||||
from app.redis_bridge import (
|
||||
publish_batch_request,
|
||||
publish_chat_request,
|
||||
push_tool_result,
|
||||
register_device,
|
||||
set_gateway_id,
|
||||
subscribe_outbound,
|
||||
unregister_device,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ws", tags=["ws-gateway"])
|
||||
|
||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||
|
||||
# Set a unique gateway instance ID on module load
|
||||
set_gateway_id(str(uuid4()))
|
||||
|
||||
|
||||
@router.websocket("/device")
|
||||
async def device_ws(websocket: WebSocket) -> None:
|
||||
"""Persistent WebSocket endpoint for Electron device connections."""
|
||||
|
||||
# ── 1. Authenticate via ?token= query parameter ──────────────────
|
||||
token = websocket.query_params.get("token", "")
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_PUBLIC_KEY,
|
||||
algorithms=["RS256"],
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
email: str | None = payload.get("email")
|
||||
if not user_id:
|
||||
raise JWTError("missing sub")
|
||||
except JWTError:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
# ── 2. Await device_hello frame ──────────────────────────────────
|
||||
try:
|
||||
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
|
||||
except (asyncio.TimeoutError, WebSocketDisconnect):
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
try:
|
||||
hello = json.loads(raw)
|
||||
if hello.get("type") != WsFrameType.device_hello:
|
||||
raise ValueError("expected device_hello as first frame")
|
||||
device_id: str = hello["device_id"]
|
||||
agent_ids: list[str] = hello.get("agent_ids", [])
|
||||
except (KeyError, ValueError, json.JSONDecodeError) as exc:
|
||||
logger.warning("handler: invalid device_hello user=%s: %s", user_id, exc)
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# ── 3. Register device in Redis ──────────────────────────────────
|
||||
await register_device(user_id, device_id)
|
||||
logger.info("handler: connected user=%s device=%s agents=%s", user_id, device_id, agent_ids)
|
||||
|
||||
# Notify downstream services that device is online (for agent trigger)
|
||||
await publish_batch_request(user_id, {
|
||||
"type": "device_online",
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"agent_ids": agent_ids,
|
||||
})
|
||||
|
||||
# ── 4. Subscribe to outbound Redis channel ───────────────────────
|
||||
pubsub = await subscribe_outbound(user_id)
|
||||
|
||||
# ── 5. Run concurrent loops ──────────────────────────────────────
|
||||
try:
|
||||
await asyncio.gather(
|
||||
_inbound_loop(websocket, user_id),
|
||||
_outbound_loop(websocket, pubsub),
|
||||
_heartbeat_loop(websocket),
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("handler: unhandled exception user=%s: %s", user_id, exc)
|
||||
finally:
|
||||
await pubsub.unsubscribe()
|
||||
await pubsub.aclose()
|
||||
await unregister_device(user_id)
|
||||
logger.info("handler: disconnected user=%s device=%s", user_id, device_id)
|
||||
|
||||
|
||||
# ── Inbound: Electron → Redis ────────────────────────────────────────
|
||||
|
||||
async def _inbound_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
"""Receive frames from Electron and route to the appropriate Redis channel."""
|
||||
async for raw in websocket.iter_text():
|
||||
try:
|
||||
frame: dict = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("handler: invalid JSON from user=%s", user_id)
|
||||
continue
|
||||
|
||||
frame_type = frame.get("type")
|
||||
|
||||
# Inject user_id so downstream services know who sent it
|
||||
frame["user_id"] = user_id
|
||||
|
||||
if frame_type == WsFrameType.tool_result:
|
||||
call_id = frame.get("id")
|
||||
if call_id:
|
||||
await push_tool_result(call_id, frame)
|
||||
else:
|
||||
logger.warning("handler: tool_result missing id user=%s", user_id)
|
||||
|
||||
elif frame_type in (WsFrameType.home_request, WsFrameType.floating_request):
|
||||
await publish_chat_request(user_id, frame)
|
||||
|
||||
elif frame_type in (WsFrameType.journey_start, WsFrameType.journey_message):
|
||||
await publish_batch_request(user_id, frame)
|
||||
|
||||
elif frame_type == "pong":
|
||||
pass # heartbeat ack
|
||||
|
||||
else:
|
||||
logger.debug("handler: unknown frame type %r user=%s", frame_type, user_id)
|
||||
|
||||
|
||||
# ── Outbound: Redis → Electron ───────────────────────────────────────
|
||||
|
||||
async def _outbound_loop(websocket: WebSocket, pubsub) -> None:
|
||||
"""Subscribe to Redis ws:out:{user_id} and forward frames to Electron."""
|
||||
while True:
|
||||
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
|
||||
if message is not None and message["type"] == "message":
|
||||
await websocket.send_text(message["data"])
|
||||
else:
|
||||
# Brief sleep to avoid busy-wait when no messages
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
# ── Heartbeat ────────────────────────────────────────────────────────
|
||||
|
||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||
"""Send ping frames every 30s to keep the connection alive."""
|
||||
while True:
|
||||
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
||||
await websocket.send_text(json.dumps({"type": "ping"}))
|
||||
Reference in New Issue
Block a user