feat: add WS Gateway and Chat Service (Step 2)
WS Gateway:
- WebSocket lifecycle handler with RS256 JWT auth
- Redis bridge: device registry, frame publishing, tool_result routing
- Inbound routing: tool_result→LPUSH, home/floating→chat pub/sub
- Outbound: subscribes to ws:out:{user_id}, forwards to Electron
- Single-worker Dockerfile (long-lived WS connections)
Chat Service:
- Redis consumer: subscribes to chat:request:* pattern
- Redis-based ws_context: tool_call→publish, BRPOP tool_result (30s timeout)
- deep_agent: single-agent runner with home/floating/stream variants
- memory_middleware: core/associative/episodic/proactive memory with Fernet
- Domain agents: task (8 tools), note (5), project (6), timeline (4)
- LLM factory via LiteLLM (100+ providers)
- Output formatter (StreamFormatter)
- POST /chat REST fallback with Traefik header auth
- Multi-worker Dockerfile with 120s timeout for LLM calls
This commit is contained in:
173
services/ws-gateway/app/handler.py
Normal file
173
services/ws-gateway/app/handler.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""WebSocket handler — device connection lifecycle.
|
||||
|
||||
Accepts Electron WS connections, authenticates JWT, registers device in Redis,
|
||||
and runs three concurrent loops:
|
||||
1. Message loop: receive frames from Electron, route to Redis
|
||||
2. Outbound loop: subscribe to Redis ws:out:{user_id}, forward to Electron
|
||||
3. Heartbeat loop: ping every 30s
|
||||
|
||||
No business logic lives here — the handler is a JSON frame router.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from shared.config import settings
|
||||
from shared.schemas import WsFrameType
|
||||
|
||||
from app.redis_bridge import (
|
||||
publish_batch_request,
|
||||
publish_chat_request,
|
||||
push_tool_result,
|
||||
register_device,
|
||||
set_gateway_id,
|
||||
subscribe_outbound,
|
||||
unregister_device,
|
||||
)
|
||||
|
||||
# Module-scoped logger; records carry this module's dotted name.
logger = logging.getLogger(__name__)

# Every route declared in this module is mounted under the "/ws" prefix.
router = APIRouter(prefix="/ws", tags=["ws-gateway"])

# Interval between heartbeat ping frames, in seconds.
_HEARTBEAT_INTERVAL = 30

# Tag this gateway process with a random instance ID as soon as the module is imported.
set_gateway_id(str(uuid4()))
|
||||
|
||||
|
||||
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
    """Persistent WebSocket endpoint for Electron device connections.

    Lifecycle:
      1. Validate the RS256 JWT passed in the ``token`` query parameter.
      2. Accept the socket and wait up to 15 seconds for a ``device_hello`` frame.
      3. Register the device in Redis and publish a ``device_online`` event.
      4. Subscribe to the user's outbound channel and run the inbound, outbound
         and heartbeat loops until the first of them finishes or fails.
      5. Cancel the remaining loops, release the pubsub and unregister the device.

    The socket is closed with code 1008 when authentication or the handshake fails.
    """
    # ── 1. Authenticate via ?token= query parameter ──────────────────
    token = websocket.query_params.get("token", "")
    try:
        payload = jwt.decode(
            token,
            settings.JWT_PUBLIC_KEY,
            algorithms=["RS256"],
        )
        user_id: str | None = payload.get("sub")
        if not user_id:
            raise JWTError("missing sub")
    except JWTError:
        # The socket has not been accepted yet; refuse it (1008 = policy violation).
        await websocket.close(code=1008)
        return

    await websocket.accept()

    # ── 2. Await device_hello frame ──────────────────────────────────
    try:
        raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
    except WebSocketDisconnect:
        # The peer is already gone, so there is nothing left to close.
        return
    except asyncio.TimeoutError:
        await websocket.close(code=1008)
        return

    try:
        hello = json.loads(raw)
        # Valid JSON that is not an object (list, string, number, ...) is rejected
        # here instead of failing later with an AttributeError on .get().
        if not isinstance(hello, dict) or hello.get("type") != WsFrameType.device_hello:
            raise ValueError("expected device_hello as first frame")
        device_id: str = hello["device_id"]
        agent_ids: list[str] = hello.get("agent_ids", [])
    except (KeyError, ValueError, json.JSONDecodeError) as exc:
        logger.warning("handler: invalid device_hello user=%s: %s", user_id, exc)
        await websocket.close(code=1008)
        return

    # ── 3. Register device in Redis ──────────────────────────────────
    await register_device(user_id, device_id)
    logger.info("handler: connected user=%s device=%s agents=%s", user_id, device_id, agent_ids)

    pubsub = None
    loops: list[asyncio.Task] = []
    try:
        # Notify downstream services that device is online (for agent trigger)
        await publish_batch_request(user_id, {
            "type": "device_online",
            "user_id": user_id,
            "device_id": device_id,
            "agent_ids": agent_ids,
        })

        # ── 4. Subscribe to outbound Redis channel ───────────────────
        pubsub = await subscribe_outbound(user_id)

        # ── 5. Run concurrent loops ──────────────────────────────────
        # The loops run as explicit tasks so that the survivors can be cancelled
        # as soon as one of them ends; asyncio.gather() would leave them running.
        loops = [
            asyncio.create_task(_inbound_loop(websocket, user_id)),
            asyncio.create_task(_outbound_loop(websocket, pubsub)),
            asyncio.create_task(_heartbeat_loop(websocket)),
        ]
        finished, _still_running = await asyncio.wait(
            loops, return_when=asyncio.FIRST_COMPLETED
        )
        for task in finished:
            if task.cancelled():
                continue
            exc = task.exception()
            # A disconnect is the normal way for a connection to end.
            if exc is not None and not isinstance(exc, WebSocketDisconnect):
                logger.warning("handler: unhandled exception user=%s: %s", user_id, exc)
    finally:
        # Stop whichever loops are still alive before touching shared resources.
        for task in loops:
            task.cancel()
        if loops:
            await asyncio.gather(*loops, return_exceptions=True)
        try:
            if pubsub is not None:
                await pubsub.unsubscribe()
                await pubsub.aclose()
        finally:
            # Always drop the registration, even if the pubsub cleanup failed.
            await unregister_device(user_id)
            logger.info("handler: disconnected user=%s device=%s", user_id, device_id)
|
||||
|
||||
|
||||
# ── Inbound: Electron → Redis ────────────────────────────────────────
|
||||
|
||||
async def _inbound_loop(websocket: WebSocket, user_id: str) -> None:
    """Receive frames from Electron and route each to the appropriate Redis channel.

    Routing by the frame's ``type`` field:
      - ``tool_result``                      -> push_tool_result (keyed by the frame ``id``)
      - ``home_request`` / ``floating_request`` -> publish_chat_request
      - ``journey_start`` / ``journey_message`` -> publish_batch_request
      - ``"pong"``                           -> ignored (heartbeat acknowledgement)
      - anything else                        -> logged at debug level and dropped

    Frames that are not valid JSON, or that are valid JSON but not an object,
    are logged and skipped so one malformed frame cannot end the connection.
    The authenticated ``user_id`` is written into every frame before forwarding.
    """
    async for raw in websocket.iter_text():
        try:
            frame = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("handler: invalid JSON from user=%s", user_id)
            continue

        # json.loads also accepts lists, strings and numbers; only objects are frames.
        if not isinstance(frame, dict):
            logger.warning("handler: non-object frame from user=%s", user_id)
            continue

        frame_type = frame.get("type")

        # Inject user_id so downstream services know who sent it
        frame["user_id"] = user_id

        if frame_type == WsFrameType.tool_result:
            call_id = frame.get("id")
            if call_id:
                await push_tool_result(call_id, frame)
            else:
                logger.warning("handler: tool_result missing id user=%s", user_id)

        elif frame_type in (WsFrameType.home_request, WsFrameType.floating_request):
            await publish_chat_request(user_id, frame)

        elif frame_type in (WsFrameType.journey_start, WsFrameType.journey_message):
            await publish_batch_request(user_id, frame)

        elif frame_type == "pong":
            pass  # heartbeat ack

        else:
            logger.debug("handler: unknown frame type %r user=%s", frame_type, user_id)
|
||||
|
||||
|
||||
# ── Outbound: Redis → Electron ───────────────────────────────────────
|
||||
|
||||
async def _outbound_loop(websocket: WebSocket, pubsub) -> None:
    """Forward frames published on the subscribed Redis channel to Electron.

    Polls ``pubsub`` forever. Every entry of type ``"message"`` has its ``data``
    sent to the socket as a text frame; a ``bytes`` payload is decoded as UTF-8
    first, because ``send_text`` needs ``str``. The loop ends only when an
    awaited call raises or the task is cancelled.
    """
    while True:
        message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
        if message is None or message["type"] != "message":
            # Short pause so the loop cannot spin if get_message() returns at once.
            await asyncio.sleep(0.01)
            continue

        data = message["data"]
        if isinstance(data, (bytes, bytearray)):
            data = data.decode("utf-8")
        await websocket.send_text(data)
|
||||
|
||||
|
||||
# ── Heartbeat ────────────────────────────────────────────────────────
|
||||
|
||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
    """Emit a JSON ``ping`` frame once per heartbeat interval, indefinitely."""
    # The frame never changes, so serialise it a single time up front.
    ping_frame = json.dumps({"type": "ping"})
    while True:
        await asyncio.sleep(_HEARTBEAT_INTERVAL)
        await websocket.send_text(ping_frame)
|
||||
49
services/ws-gateway/app/main.py
Normal file
49
services/ws-gateway/app/main.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""WS Gateway — stateless WebSocket proxy.
|
||||
|
||||
Accepts Electron device connections, authenticates JWT (RS256 public key),
|
||||
and routes frames between Electron and downstream services via Redis pub/sub.
|
||||
|
||||
This service has NO business logic — it only routes JSON frames.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI
|
||||
from shared.config import settings
|
||||
|
||||
# Root logging setup: INFO and above, rendered as "<time> <level> <logger>: <message>".
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Performs no startup work; once the context exits it closes the shared
    Redis client.
    """
    # Nothing to prepare before serving: hand control straight to the app.
    yield

    # Shutdown path: the Redis client is imported here, after the yield.
    from shared.redis import redis_client

    await redis_client.aclose()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
    """Build the WS Gateway FastAPI application.

    Mounts the WebSocket router under ``/api/v1`` and adds a
    ``GET /api/v1/health`` endpoint. Interactive docs are served at ``/docs``
    only when ``settings.ENV`` is ``"dev"``; ReDoc is always disabled.
    """
    docs_path = "/docs" if settings.ENV == "dev" else None
    application = FastAPI(
        title="Adiuva WS Gateway",
        version="0.1.0",
        docs_url=docs_path,
        redoc_url=None,
        lifespan=lifespan,
    )

    # The handler module is imported inside the factory, not at module top.
    from app.handler import router

    application.include_router(router, prefix="/api/v1")

    @application.get("/api/v1/health", tags=["health"])
    async def health() -> dict:
        """Report the service name, version and a static ``ok`` status."""
        return {"status": "ok", "service": "ws-gateway", "version": application.version}

    return application
|
||||
|
||||
|
||||
# Module-level application instance, created when this module is imported.
app = create_app()
|
||||
104
services/ws-gateway/app/redis_bridge.py
Normal file
104
services/ws-gateway/app/redis_bridge.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Redis bridge — device registry + pub/sub routing.
|
||||
|
||||
All inter-service communication passes through Redis:
|
||||
- Device registry: HSET/DEL ws:devices:{user_id}
|
||||
- Outbound frames: Subscribe ws:out:{user_id}
|
||||
- Chat requests: Publish chat:request:{user_id}
|
||||
- Batch requests: Publish batch:request:{user_id}
|
||||
- Tool results: LPUSH tool:result:{call_id}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from shared.redis import (
|
||||
batch_request_channel,
|
||||
chat_request_channel,
|
||||
device_key,
|
||||
redis_client,
|
||||
tool_result_key,
|
||||
ws_out_channel,
|
||||
)
|
||||
|
||||
# Module-scoped logger for the Redis bridge.
logger = logging.getLogger(__name__)

# Identifier of this gateway replica; stays empty until set_gateway_id() is called.
_GATEWAY_ID: str = ""
|
||||
|
||||
|
||||
def set_gateway_id(gid: str) -> None:
    """Store *gid* as the instance ID of this gateway replica."""
    # Rebinds the module-level value that register_device() writes to Redis.
    global _GATEWAY_ID
    _GATEWAY_ID = gid
|
||||
|
||||
|
||||
def get_gateway_id() -> str:
    """Return the instance ID currently stored for this gateway replica."""
    return _GATEWAY_ID
|
||||
|
||||
|
||||
# ── Device Registry ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def register_device(user_id: str, device_id: str) -> None:
    """Record the user's connected device, and the gateway serving it, in Redis.

    Writes ``device_id`` and ``gateway_id`` into the hash stored at
    ``device_key(user_id)``.
    """
    registration = {
        "device_id": device_id,
        "gateway_id": _GATEWAY_ID,
    }
    await redis_client.hset(device_key(user_id), mapping=registration)
    logger.info("redis_bridge: registered user=%s device=%s gateway=%s", user_id, device_id, _GATEWAY_ID)
|
||||
|
||||
|
||||
async def unregister_device(user_id: str) -> None:
    """Delete the user's device registration key from Redis."""
    # NOTE(review): the key is deleted without checking which device or gateway
    # wrote it -- confirm this is safe if the same user reconnects elsewhere.
    await redis_client.delete(device_key(user_id))
    logger.info("redis_bridge: unregistered user=%s", user_id)
|
||||
|
||||
|
||||
async def is_device_online(user_id: str) -> bool:
    """Return True when a registration key exists in Redis for *user_id*."""
    matches = await redis_client.exists(device_key(user_id))
    return matches > 0
|
||||
|
||||
|
||||
# ── Frame Routing ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def publish_chat_request(user_id: str, frame: dict) -> None:
    """Publish *frame*, JSON-encoded, on the user's chat request channel."""
    target = chat_request_channel(user_id)
    encoded = json.dumps(frame)
    await redis_client.publish(target, encoded)
    logger.debug("redis_bridge: published chat_request user=%s", user_id)
|
||||
|
||||
|
||||
async def publish_batch_request(user_id: str, frame: dict) -> None:
    """Publish *frame*, JSON-encoded, on the user's batch request channel."""
    target = batch_request_channel(user_id)
    encoded = json.dumps(frame)
    await redis_client.publish(target, encoded)
    logger.debug("redis_bridge: published batch_request user=%s", user_id)
|
||||
|
||||
|
||||
async def push_tool_result(call_id: str, result: dict) -> None:
    """Push a tool_result to the Redis list for the waiting service.

    Chat/Batch services do BRPOP on this key with a 30s timeout.

    The LPUSH and the 60-second EXPIRE are sent as one MULTI/EXEC transaction,
    so the list can never be left in Redis without its TTL if the connection
    fails between the two commands.
    """
    key = tool_result_key(call_id)
    payload = json.dumps(result)
    async with redis_client.pipeline(transaction=True) as pipe:
        # Pipeline commands are only queued here; execute() sends them together.
        pipe.lpush(key, payload)
        # Auto-expire after 60s to prevent stale keys
        pipe.expire(key, 60)
        await pipe.execute()
    logger.debug("redis_bridge: pushed tool_result call_id=%s", call_id)
|
||||
|
||||
|
||||
async def subscribe_outbound(user_id: str):
    """Return an async pubsub subscription for frames to send to Electron.

    Chat/Batch services publish to ws:out:{user_id} and this gateway
    forwards them to the connected WebSocket.

    The caller owns the returned pubsub and must unsubscribe and close it.
    If subscribing fails, the pubsub is closed here before the error is
    re-raised, so it is not leaked.
    """
    channel = ws_out_channel(user_id)
    pubsub = redis_client.pubsub()
    try:
        await pubsub.subscribe(channel)
    except BaseException:
        # Covers cancellation too: release the pubsub, then propagate.
        await pubsub.aclose()
        raise
    return pubsub
|
||||
Reference in New Issue
Block a user