feat: add WS Gateway and Chat Service (Step 2)

WS Gateway:
- WebSocket lifecycle handler with RS256 JWT auth
- Redis bridge: device registry, frame publishing, tool_result routing
- Inbound routing: tool_result→LPUSH, home/floating→chat pub/sub
- Outbound: subscribes to ws:out:{user_id}, forwards to Electron
- Single-worker Dockerfile (long-lived WS connections)

Chat Service:
- Redis consumer: subscribes to chat:request:* pattern
- Redis-based ws_context: tool_call→publish, BRPOP tool_result (30s timeout)
- deep_agent: single-agent runner with home/floating/stream variants
- memory_middleware: core/associative/episodic/proactive memory with Fernet
- Domain agents: task (8 tools), note (5), project (6), timeline (4)
- LLM factory via LiteLLM (100+ providers)
- Output formatter (StreamFormatter)
- POST /chat REST fallback with Traefik header auth
- Multi-worker Dockerfile with 120s timeout for LLM calls
This commit is contained in:
Roberto Musso
2026-03-22 01:20:11 +01:00
parent 1e2e395676
commit 90018af311
21 changed files with 2731 additions and 1 deletions

View File

@@ -0,0 +1,173 @@
"""WebSocket handler — device connection lifecycle.
Accepts Electron WS connections, authenticates JWT, registers device in Redis,
and runs two concurrent loops:
1. Message loop: receive frames from Electron, route to Redis
2. Outbound loop: subscribe to Redis ws:out:{user_id}, forward to Electron
3. Heartbeat loop: ping every 30s
No business logic lives here — the handler is a JSON frame router.
"""
from __future__ import annotations
import asyncio
import json
import logging
from uuid import uuid4
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from shared.config import settings
from shared.schemas import WsFrameType
from app.redis_bridge import (
publish_batch_request,
publish_chat_request,
push_tool_result,
register_device,
set_gateway_id,
subscribe_outbound,
unregister_device,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["ws-gateway"])
_HEARTBEAT_INTERVAL = 30 # seconds
# Set a unique gateway instance ID on module load
set_gateway_id(str(uuid4()))
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
"""Persistent WebSocket endpoint for Electron device connections."""
# ── 1. Authenticate via ?token= query parameter ──────────────────
token = websocket.query_params.get("token", "")
try:
payload = jwt.decode(
token,
settings.JWT_PUBLIC_KEY,
algorithms=["RS256"],
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id:
raise JWTError("missing sub")
except JWTError:
await websocket.close(code=1008)
return
await websocket.accept()
# ── 2. Await device_hello frame ──────────────────────────────────
try:
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
except (asyncio.TimeoutError, WebSocketDisconnect):
await websocket.close(code=1008)
return
try:
hello = json.loads(raw)
if hello.get("type") != WsFrameType.device_hello:
raise ValueError("expected device_hello as first frame")
device_id: str = hello["device_id"]
agent_ids: list[str] = hello.get("agent_ids", [])
except (KeyError, ValueError, json.JSONDecodeError) as exc:
logger.warning("handler: invalid device_hello user=%s: %s", user_id, exc)
await websocket.close(code=1008)
return
# ── 3. Register device in Redis ──────────────────────────────────
await register_device(user_id, device_id)
logger.info("handler: connected user=%s device=%s agents=%s", user_id, device_id, agent_ids)
# Notify downstream services that device is online (for agent trigger)
await publish_batch_request(user_id, {
"type": "device_online",
"user_id": user_id,
"device_id": device_id,
"agent_ids": agent_ids,
})
# ── 4. Subscribe to outbound Redis channel ───────────────────────
pubsub = await subscribe_outbound(user_id)
# ── 5. Run concurrent loops ──────────────────────────────────────
try:
await asyncio.gather(
_inbound_loop(websocket, user_id),
_outbound_loop(websocket, pubsub),
_heartbeat_loop(websocket),
)
except WebSocketDisconnect:
pass
except Exception as exc:
logger.warning("handler: unhandled exception user=%s: %s", user_id, exc)
finally:
await pubsub.unsubscribe()
await pubsub.aclose()
await unregister_device(user_id)
logger.info("handler: disconnected user=%s device=%s", user_id, device_id)
# ── Inbound: Electron → Redis ────────────────────────────────────────
async def _inbound_loop(websocket: WebSocket, user_id: str) -> None:
"""Receive frames from Electron and route to the appropriate Redis channel."""
async for raw in websocket.iter_text():
try:
frame: dict = json.loads(raw)
except json.JSONDecodeError:
logger.warning("handler: invalid JSON from user=%s", user_id)
continue
frame_type = frame.get("type")
# Inject user_id so downstream services know who sent it
frame["user_id"] = user_id
if frame_type == WsFrameType.tool_result:
call_id = frame.get("id")
if call_id:
await push_tool_result(call_id, frame)
else:
logger.warning("handler: tool_result missing id user=%s", user_id)
elif frame_type in (WsFrameType.home_request, WsFrameType.floating_request):
await publish_chat_request(user_id, frame)
elif frame_type in (WsFrameType.journey_start, WsFrameType.journey_message):
await publish_batch_request(user_id, frame)
elif frame_type == "pong":
pass # heartbeat ack
else:
logger.debug("handler: unknown frame type %r user=%s", frame_type, user_id)
# ── Outbound: Redis → Electron ───────────────────────────────────────
async def _outbound_loop(websocket: WebSocket, pubsub) -> None:
"""Subscribe to Redis ws:out:{user_id} and forward frames to Electron."""
while True:
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if message is not None and message["type"] == "message":
await websocket.send_text(message["data"])
else:
# Brief sleep to avoid busy-wait when no messages
await asyncio.sleep(0.01)
# ── Heartbeat ────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None:
"""Send ping frames every 30s to keep the connection alive."""
while True:
await asyncio.sleep(_HEARTBEAT_INTERVAL)
await websocket.send_text(json.dumps({"type": "ping"}))

View File

@@ -0,0 +1,49 @@
"""WS Gateway — stateless WebSocket proxy.
Accepts Electron device connections, authenticates JWT (RS256 public key),
and routes frames between Electron and downstream services via Redis pub/sub.
This service has NO business logic — it only routes JSON frames.
"""
from contextlib import asynccontextmanager
import logging
from fastapi import FastAPI
from shared.config import settings
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
@asynccontextmanager
async def lifespan(app: FastAPI):
yield
from shared.redis import redis_client
await redis_client.aclose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva WS Gateway",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
from app.handler import router
app.include_router(router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "service": "ws-gateway", "version": app.version}
return app
app = create_app()

View File

@@ -0,0 +1,104 @@
"""Redis bridge — device registry + pub/sub routing.
All inter-service communication passes through Redis:
- Device registry: HSET/HDEL ws:devices:{user_id}
- Outbound frames: Subscribe ws:out:{user_id}
- Chat requests: Publish chat:request:{user_id}
- Batch requests: Publish batch:request:{user_id}
- Tool results: LPUSH tool:result:{call_id}
"""
from __future__ import annotations
import json
import logging
from shared.redis import (
batch_request_channel,
chat_request_channel,
device_key,
redis_client,
tool_result_key,
ws_out_channel,
)
logger = logging.getLogger(__name__)
# Instance ID for this gateway replica (set on startup)
_GATEWAY_ID: str = ""
def set_gateway_id(gid: str) -> None:
global _GATEWAY_ID
_GATEWAY_ID = gid
def get_gateway_id() -> str:
return _GATEWAY_ID
# ── Device Registry ──────────────────────────────────────────────────
async def register_device(user_id: str, device_id: str) -> None:
"""Register a connected device in Redis."""
key = device_key(user_id)
await redis_client.hset(key, mapping={
"device_id": device_id,
"gateway_id": _GATEWAY_ID,
})
logger.info("redis_bridge: registered user=%s device=%s gateway=%s", user_id, device_id, _GATEWAY_ID)
async def unregister_device(user_id: str) -> None:
"""Remove device registration from Redis."""
key = device_key(user_id)
await redis_client.delete(key)
logger.info("redis_bridge: unregistered user=%s", user_id)
async def is_device_online(user_id: str) -> bool:
"""Check if a device is registered."""
key = device_key(user_id)
return await redis_client.exists(key) > 0
# ── Frame Routing ────────────────────────────────────────────────────
async def publish_chat_request(user_id: str, frame: dict) -> None:
"""Forward a chat request frame to the Chat Service via Redis."""
channel = chat_request_channel(user_id)
await redis_client.publish(channel, json.dumps(frame))
logger.debug("redis_bridge: published chat_request user=%s", user_id)
async def publish_batch_request(user_id: str, frame: dict) -> None:
"""Forward a batch request frame to the Batch Agent Service via Redis."""
channel = batch_request_channel(user_id)
await redis_client.publish(channel, json.dumps(frame))
logger.debug("redis_bridge: published batch_request user=%s", user_id)
async def push_tool_result(call_id: str, result: dict) -> None:
"""Push a tool_result to the Redis list for the waiting service.
Chat/Batch services do BRPOP on this key with a 30s timeout.
"""
key = tool_result_key(call_id)
await redis_client.lpush(key, json.dumps(result))
# Auto-expire after 60s to prevent stale keys
await redis_client.expire(key, 60)
logger.debug("redis_bridge: pushed tool_result call_id=%s", call_id)
async def subscribe_outbound(user_id: str):
"""Return an async pubsub subscription for frames to send to Electron.
Chat/Batch services publish to ws:out:{user_id} and this gateway
forwards them to the connected WebSocket.
"""
channel = ws_out_channel(user_id)
pubsub = redis_client.pubsub()
await pubsub.subscribe(channel)
return pubsub