504 lines
18 KiB
Python
504 lines
18 KiB
Python
"""Device WebSocket endpoint.

Persistent connection from Electron devices to the backend.

    WS /api/v1/ws/device?token=<jwt>

Auth: JWT passed as ``?token=`` query parameter (a Bearer header is not
available during the WebSocket handshake).

Protocol:
    1. Client connects → JWT validated → connection accepted.
    2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
    3. Backend registers the connection in ``DeviceConnectionManager``.
    4. Session enters message dispatch loop + heartbeat.

Incoming frame dispatch:
    - ``tool_result``      → resolves a pending tool-call Future.
    - ``home_request``     → streams a home chat response.
    - ``floating_request`` → streams a floating (scoped) chat response.
    - ``brief_request``    → streams a home or project brief.
    - ``journey_start``    → starts a guided setup journey session.
    - ``journey_message``  → continues a journey conversation.
    - ``pong``             → heartbeat acknowledgement (currently a no-op;
                             last-seen is not tracked).
    - unknown types        → logged, ignored.

Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.

On disconnect:
    - Unregisters from DeviceConnectionManager.
    - Marks all ``running`` AgentRunLog rows for this user as ``error``
      with message "device disconnected".
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from uuid import uuid4
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
from jose import JWTError, jwt
|
|
from sqlalchemy import update
|
|
|
|
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
|
from app.config.settings import settings
|
|
from app.core.agent_runner import trigger_pending_runs
|
|
from app.core.brief_agent import run_home_brief, run_project_brief
|
|
from app.core.deep_agent import run_floating_stream, run_home_stream
|
|
from app.core.device_manager import device_manager
|
|
from app.core.memory_middleware import MemoryMiddleware
|
|
from app.core.output_formatter import StreamFormatter
|
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
|
from app.db import async_session
|
|
from app.models import AgentRunLog
|
|
from app.schemas import WsFrameType, WsStreamEnd
|
|
|
|
logger = logging.getLogger(__name__)

router = APIRouter(tags=["device-ws"], prefix="/ws")

# Seconds between outgoing ``{"type": "ping"}`` frames (used by _heartbeat_loop).
_HEARTBEAT_INTERVAL: int = 30

# Intended grace window (seconds) for a pong after a ping.
# NOTE(review): not referenced anywhere in this module — nothing currently
# enforces a pong timeout.
_PONG_TIMEOUT: int = 10
|
|
|
|
|
|
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
    """Persistent WebSocket endpoint for Electron device connections.

    Authentication is via ``?token=<jwt>`` query parameter; the JWT ``sub``
    claim is used as the user id. The first frame must be a ``device_hello``
    JSON object carrying ``device_id`` (and optionally ``agent_ids``), sent
    within 15 seconds of the connection being accepted.

    After registration the message loop and the heartbeat run as sibling
    tasks. As soon as either finishes (normally or by raising), the other is
    cancelled. Disconnect cleanup (unregistering the device and failing
    in-progress runs) therefore happens immediately. It no longer waits for
    the next heartbeat send to fail.
    """
    # ── 1. Authenticate before accepting ─────────────────────────────
    token = websocket.query_params.get("token", "")
    try:
        payload = jwt.decode(
            token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
        user_id: str | None = payload.get("sub")
        if not user_id:
            raise JWTError("missing sub")
    except JWTError:
        await websocket.close(code=1008)  # Policy Violation
        return

    await websocket.accept()

    # ── 2. Await device_hello frame ───────────────────────────────────
    try:
        raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
    except asyncio.TimeoutError:
        await websocket.close(code=1008)
        return
    except WebSocketDisconnect:
        # The client already disconnected; there is no open socket to close.
        return

    try:
        hello = json.loads(raw)
        # A valid JSON value that is not an object (list, string, number)
        # would otherwise blow up on .get() with an uncaught AttributeError.
        if not isinstance(hello, dict):
            raise ValueError("device_hello must be a JSON object")
        if hello.get("type") != WsFrameType.device_hello:
            raise ValueError("expected device_hello as first frame")
        device_id: str = hello["device_id"]
        agent_ids: list[str] = hello.get("agent_ids", [])
    except (KeyError, ValueError, json.JSONDecodeError) as exc:
        logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
        await websocket.close(code=1008)
        return

    # ── 3. Register connection ────────────────────────────────────────
    device_manager.register(user_id, device_id, websocket)
    logger.info(
        "device_ws: connected user=%s device=%s agents=%s",
        user_id,
        device_id,
        agent_ids,
    )

    # Trigger any overdue agent runs now that the device is connected.
    asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))

    # ── 4. Concurrent message loop + heartbeat ────────────────────────
    # asyncio.gather would not cancel the surviving coroutine when the other
    # one ends, so both run as tasks and the first to finish ends the session.
    session_tasks = [
        asyncio.create_task(_message_loop(websocket, user_id)),
        asyncio.create_task(_heartbeat_loop(websocket)),
    ]
    try:
        done, _ = await asyncio.wait(
            session_tasks, return_when=asyncio.FIRST_COMPLETED
        )
        for task in done:
            task.result()  # re-raise whatever ended the session, if anything
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
    finally:
        for task in session_tasks:
            task.cancel()
        # Let the cancellations land (and retrieve any exceptions) so neither
        # loop outlives the connection.
        await asyncio.gather(*session_tasks, return_exceptions=True)
        device_manager.unregister(user_id)
        logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
        await _mark_runs_disconnected(user_id)
|
|
|
|
|
|
# ── Message dispatch loop ─────────────────────────────────────────────

# Strong references to the fire-and-forget handler tasks spawned by
# _message_loop. The event loop keeps only weak references to tasks, so an
# otherwise unreferenced task can be garbage-collected before it finishes.
# Each task removes itself from this set when done.
_background_tasks: set[asyncio.Task] = set()


async def _message_loop(websocket: WebSocket, user_id: str) -> None:
    """Receive frames from Electron and dispatch to the appropriate handler.

    Args:
        websocket: The accepted device socket.
        user_id: Authenticated user id (JWT ``sub``) owning this connection.

    ``tool_result`` frames resolve a pending tool call inline. Request and
    journey frames run their handler in a background task, so one slow
    request never blocks the receive loop. ``pong`` is accepted and ignored.
    Unknown types are logged at debug level. Malformed frames (invalid JSON,
    or a JSON value that is not an object) are logged and skipped without
    dropping the connection. Returns once ``websocket.iter_text()`` is
    exhausted.
    """

    def spawn(coro) -> None:
        # Keep a strong reference until the task completes.
        task = asyncio.create_task(coro)
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)

    async for raw in websocket.iter_text():
        try:
            frame = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("device_ws: invalid JSON from user=%s", user_id)
            continue

        if not isinstance(frame, dict):
            # e.g. a bare list or string: frame.get() below would raise and the
            # exception would tear down the whole device session.
            logger.warning("device_ws: non-object frame from user=%s", user_id)
            continue

        frame_type = frame.get("type")

        if frame_type == WsFrameType.tool_result:
            call_id = frame.get("id")
            if call_id:
                device_manager.resolve_pending_call(user_id, call_id, frame)
            else:
                logger.warning(
                    "device_ws: tool_result missing id from user=%s", user_id
                )

        elif frame_type == WsFrameType.home_request:
            spawn(_handle_home_request(websocket, user_id, frame))

        elif frame_type == WsFrameType.floating_request:
            spawn(_handle_floating_request(websocket, user_id, frame))

        elif frame_type == WsFrameType.brief_request:
            spawn(_handle_brief_request(websocket, user_id, frame))

        elif frame_type == WsFrameType.journey_start:
            spawn(_handle_journey_start(websocket, user_id, frame))

        elif frame_type == WsFrameType.journey_message:
            spawn(_handle_journey_message(websocket, user_id, frame))

        elif frame_type == "pong":
            # Heartbeat ack — nothing to do, connection is alive.
            pass

        else:
            logger.debug(
                "device_ws: unknown frame type %r from user=%s", frame_type, user_id
            )
|
|
|
|
|
|
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
|
|
|
async def _make_ws_executor(websocket: WebSocket, user_id: str):
    """Return a callback that sends tool_call frames and awaits tool_result.

    The returned coroutine function takes a tool-call payload, which must
    contain an ``"id"`` key. It first registers a pending call for that id
    with ``device_manager``. It then sends a copy of the payload with
    ``type`` set to ``tool_call``. Finally it returns whatever the pending
    call resolves to. Presumably that is the ``tool_result`` frame handed to
    ``resolve_pending_call`` by the message loop; confirm against
    device_manager. The caller's payload dict is not modified.
    """

    async def _executor(payload: dict) -> dict:
        call_id = payload["id"]
        # Register the pending call BEFORE sending. Awaiting send_text yields
        # to the event loop, so a fast tool_result could otherwise be
        # dispatched by the message loop before any future exists to receive
        # it, and this await would then never complete.
        # NOTE(review): if send_text raises, the pending entry stays
        # registered; confirm device_manager cleans these up on unregister.
        future = device_manager.create_pending_call(user_id, call_id)
        frame = {**payload, "type": WsFrameType.tool_call}
        await websocket.send_text(json.dumps(frame))
        return await future

    return _executor
|
|
|
|
|
|
async def _stream_chat_request(
    websocket: WebSocket,
    user_id: str,
    *,
    kind: str,
    request_id: str,
    session_id: str,
    message: str,
    extra_context: dict,
    run_stream,
) -> None:
    """Shared body of the home/floating chat handlers.

    Steps:
      1. Enrich the context with memory.
      2. Stream ``run_stream`` output through ``StreamFormatter`` back on the
         socket.
      3. Store the exchange as a memory episode.

    Args:
        websocket: Device socket to stream frames on.
        user_id: Authenticated user id.
        kind: Frame-type label used in log lines ("home_request" /
            "floating_request").
        request_id: Id passed to ``StreamFormatter`` and the memory calls.
        session_id: Conversation session id used for memory.
        message: The user's message.
        extra_context: Handler-specific context keys. They are placed first,
            so ``_debug`` and the memory keys take precedence on collision.
        run_stream: ``run_home_stream`` or ``run_floating_stream``. It is
            called as ``run_stream(user_id, message, context)``.
    """
    # ── Memory: enrich context before LLM call ────────────────────────
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        memory_context = await memory.enrich_context(
            user_id,
            message,
            trace_id=request_id,
            session_id=session_id,
        )

    context: dict = {
        **extra_context,
        "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
        **memory_context,
    }

    executor = await _make_ws_executor(websocket, user_id)
    set_client_executor(executor)
    response_chunks: list[str] = []
    try:
        event_stream = run_stream(user_id, message, context)
        formatter = StreamFormatter(request_id=request_id)
        async for ws_frame in formatter.format(event_stream):
            await websocket.send_text(ws_frame.model_dump_json())
            # Collect text chunks to build the full response for episode storage
            if ws_frame.type == "stream_text":  # type: ignore[union-attr]
                response_chunks.append(ws_frame.chunk)  # type: ignore[union-attr]
    except Exception as exc:
        logger.exception(
            "device_ws: %s failed user=%s req=%s: %s",
            kind, user_id, request_id, exc,
        )
        # Mirror the brief_request error path: give the client a terminal
        # WsStreamEnd carrying the error for this request_id.
        try:
            await websocket.send_text(
                WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
            )
        except Exception:
            # The socket itself may be what failed; the error is already logged.
            logger.debug(
                "device_ws: could not send %s error frame req=%s", kind, request_id
            )
    finally:
        clear_client_executor()

    response_text = "".join(response_chunks)

    # ── Memory: store episode after response ──────────────────────────
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        await memory.store_episode(
            user_id, session_id, message, response_text, trace_id=request_id
        )
    logger.info(
        "device_ws: %s_end user=%s req=%s session=%s response_chars=%d",
        kind,
        user_id,
        request_id,
        session_id,
        len(response_text),
    )


async def _handle_home_request(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Handle a home_request frame.

    Streams ``run_home_stream`` output, formatted by ``StreamFormatter``, back
    on the socket. The frame's ``conversation_history`` is passed through in
    the context. ``request_id`` and ``session_id`` default to fresh UUID4
    strings when the frame omits them.
    """
    request_id = frame.get("request_id") or str(uuid4())
    message: str = frame.get("message", "")
    session_id: str = frame.get("session_id") or str(uuid4())
    logger.info(
        "device_ws: home_request_start user=%s req=%s session=%s msg=%s",
        user_id,
        request_id,
        session_id,
        message[:200],
    )
    await _stream_chat_request(
        websocket,
        user_id,
        kind="home_request",
        request_id=request_id,
        session_id=session_id,
        message=message,
        extra_context={"conversation_history": frame.get("conversation_history", [])},
        run_stream=run_home_stream,
    )


async def _handle_floating_request(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Handle a floating_request frame.

    Streams ``run_floating_stream`` output, formatted by ``StreamFormatter``,
    back on the socket. The frame's ``scope`` dict is passed through in the
    context. ``request_id`` and ``session_id`` default to fresh UUID4 strings
    when the frame omits them.
    """
    request_id = frame.get("request_id") or str(uuid4())
    message: str = frame.get("message", "")
    session_id: str = frame.get("session_id") or str(uuid4())
    scope: dict = frame.get("scope", {})
    logger.info(
        "device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
        user_id,
        request_id,
        session_id,
        json.dumps(scope, ensure_ascii=True)[:200],
        message[:200],
    )
    await _stream_chat_request(
        websocket,
        user_id,
        kind="floating_request",
        request_id=request_id,
        session_id=session_id,
        message=message,
        extra_context={"scope": scope},
        run_stream=run_floating_stream,
    )
|
|
|
|
|
|
async def _handle_brief_request(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Handle a brief_request frame — streams plain-text brief back on the socket.

    With ``mode == "project"`` the frame must carry a UUID-formatted
    ``project_id``, and ``run_project_brief`` is used. Any other mode
    (default ``"home"``) uses ``run_home_brief``. Validation and streaming
    failures are reported to the client as a ``WsStreamEnd`` frame carrying
    the error text.

    No episode storage — briefs are not conversations.
    """
    from uuid import UUID

    request_id = frame.get("request_id") or str(uuid4())
    session_id = frame.get("session_id") or str(uuid4())
    mode: str = frame.get("mode", "home")
    project_id: str | None = frame.get("project_id")

    async def send_error_end(error: str) -> None:
        # Best-effort terminal frame: if the socket is what failed, sending
        # again would raise a second exception out of this background task.
        try:
            await websocket.send_text(
                WsStreamEnd(request_id=request_id, error=error).model_dump_json()
            )
        except Exception:
            logger.debug(
                "device_ws: could not send brief_request error frame req=%s",
                request_id,
            )

    logger.info(
        "device_ws: brief_request_start user=%s req=%s mode=%s project_id=%s",
        user_id, request_id, mode, project_id,
    )

    # Validate project_id for project mode before touching LLM.
    if mode == "project":
        try:
            if not project_id:
                raise ValueError("project_id required for project mode")
            UUID(project_id)
        except (ValueError, AttributeError, TypeError) as exc:
            logger.warning(
                "device_ws: brief_request invalid project_id user=%s req=%s: %s",
                user_id, request_id, exc,
            )
            await send_error_end(str(exc))
            return

    # Enrich context with memory (no user message — use empty string as probe).
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        memory_context = await memory.enrich_context(
            user_id,
            "",
            trace_id=request_id,
            session_id=session_id,
        )

    context: dict = {
        "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
        **memory_context,
    }

    executor = await _make_ws_executor(websocket, user_id)
    set_client_executor(executor)
    try:
        if mode == "project":
            event_stream = run_project_brief(user_id, project_id, context)  # type: ignore[arg-type]
        else:
            event_stream = run_home_brief(user_id, context)

        formatter = StreamFormatter(request_id=request_id)
        async for ws_frame in formatter.format(event_stream):
            await websocket.send_text(ws_frame.model_dump_json())
    except Exception as exc:
        logger.exception(
            "device_ws: brief_request failed user=%s req=%s: %s",
            user_id, request_id, exc,
        )
        await send_error_end(str(exc))
    finally:
        clear_client_executor()

    logger.info(
        "device_ws: brief_request_end user=%s req=%s mode=%s",
        user_id, request_id, mode,
    )
|
|
|
|
|
|
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
|
|
|
|
|
async def _handle_journey_start(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Handle a journey_start frame — explores directory and sends first question.

    Installs a WebSocket tool executor, calls ``handle_journey_start`` and
    sends its reply as JSON. On failure the client receives a terminal
    ``journey_reply`` (``done=True``) describing the error. That send is
    best-effort, since the socket itself may be what failed.
    """
    executor = await _make_ws_executor(websocket, user_id)
    set_client_executor(executor)
    try:
        reply = await handle_journey_start(user_id, frame)
        await websocket.send_text(json.dumps(reply))
    except Exception as exc:
        logger.exception(
            "device_ws: journey_start failed user=%s: %s", user_id, exc
        )
        error_reply = {
            "type": "journey_reply",
            "session_id": frame.get("session_id", ""),
            "message": f"Failed to start journey: {exc}",
            "done": True,
            "prompt_template": None,
        }
        try:
            await websocket.send_text(json.dumps(error_reply))
        except Exception:
            # Don't let a dead socket raise a second exception out of this task.
            logger.debug(
                "device_ws: could not send journey_start error reply user=%s", user_id
            )
    finally:
        clear_client_executor()
|
|
|
|
|
|
async def _handle_journey_message(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Handle a journey_message frame — continues the journey conversation.

    Installs a WebSocket tool executor, calls ``handle_journey_message`` and
    sends its reply as JSON. On failure the client receives a terminal
    ``journey_reply`` (``done=True``) for the frame's ``session_id``. That
    send is best-effort, since the socket itself may be what failed.
    """
    executor = await _make_ws_executor(websocket, user_id)
    set_client_executor(executor)
    try:
        reply = await handle_journey_message(user_id, frame)
        await websocket.send_text(json.dumps(reply))
    except Exception as exc:
        session_id = frame.get("session_id", "")
        logger.exception(
            "device_ws: journey_message failed user=%s session=%s: %s",
            user_id, session_id, exc,
        )
        error_reply = {
            "type": "journey_reply",
            "session_id": session_id,
            "message": f"Journey error: {exc}",
            "done": True,
            "prompt_template": None,
        }
        try:
            await websocket.send_text(json.dumps(error_reply))
        except Exception:
            # Don't let a dead socket raise a second exception out of this task.
            logger.debug(
                "device_ws: could not send journey_message error reply user=%s session=%s",
                user_id, session_id,
            )
    finally:
        clear_client_executor()
|
|
|
|
|
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
|
|
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
    """Emit a ``{"type": "ping"}`` frame every ``_HEARTBEAT_INTERVAL`` seconds.

    Sleeps first, then pings, forever. It never returns on its own; it ends
    only when ``send_text`` raises or the task is cancelled. Pong replies are
    not checked here.
    """
    ping_frame = json.dumps({"type": "ping"})
    while True:
        await asyncio.sleep(_HEARTBEAT_INTERVAL)
        await websocket.send_text(ping_frame)
|
|
|
|
|
|
# ── Disconnect cleanup ────────────────────────────────────────────────
|
|
|
|
async def _mark_runs_disconnected(user_id: str) -> None:
    """Mark all in-progress AgentRunLog rows as 'error' for this user.

    Every row with ``status == "running"`` for ``user_id`` is updated to
    ``status="error"`` with ``errors=["device disconnected"]``, then
    committed. Best-effort: any failure is logged with its traceback and
    swallowed, so connection teardown is never interrupted.

    NOTE(review): the filter is per-user, not per-device/connection. Runs
    belonging to any other live session of the same user are failed too.
    """
    try:
        async with async_session() as db:
            await db.execute(
                update(AgentRunLog)
                .where(
                    AgentRunLog.user_id == user_id,
                    AgentRunLog.status == "running",
                )
                .values(
                    status="error",
                    errors=["device disconnected"],
                )
            )
            await db.commit()
    except Exception as exc:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(
            "device_ws: failed to mark runs as disconnected for user=%s: %s",
            user_id,
            exc,
        )
|