From 76c8f2bdad144383e3c986a0a9b83bc404c84327 Mon Sep 17 00:00:00 2001 From: roberto Date: Sun, 8 Mar 2026 22:01:11 +0100 Subject: [PATCH] step-5: unify ws handler (device_ws.py, chat.py) - device_ws.py: dispatch home_request/popup_request to HomeFormatter/PopupFormatter via async tasks; each request gets a UUID request_id for frame correlation - chat.py: remove chat_stream WS endpoint (superseded by unified device WS); keep POST /chat REST fallback unchanged - 5 new integration tests pass; all 22 existing device_ws tests still pass Co-Authored-By: Claude Sonnet 4.6 --- V3_MIGRATION_PLAN.md | 2 +- app/api/routes/chat.py | 61 ++------------ app/api/routes/device_ws.py | 86 +++++++++++++++++++- tests/test_ws_unified.py | 157 ++++++++++++++++++++++++++++++++++++ 4 files changed, 249 insertions(+), 57 deletions(-) create mode 100644 tests/test_ws_unified.py diff --git a/V3_MIGRATION_PLAN.md b/V3_MIGRATION_PLAN.md index 30eca16..d2ef537 100644 --- a/V3_MIGRATION_PLAN.md +++ b/V3_MIGRATION_PLAN.md @@ -253,7 +253,7 @@ pytest tests/test_ws_unified.py ``` **Status**: -- [ ] Step 5 complete +- [x] Step 5 complete **Commit**: After tests pass, commit with: ``` diff --git a/app/api/routes/chat.py b/app/api/routes/chat.py index ba0a6ff..1cd0fa4 100644 --- a/app/api/routes/chat.py +++ b/app/api/routes/chat.py @@ -1,23 +1,19 @@ -"""Chat routes: POST /chat and WebSocket /chat/stream.""" +"""Chat routes: POST /chat (REST fallback). + +WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device). +""" from __future__ import annotations -import asyncio -import json - -from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends from fastapi.responses import JSONResponse -from jose import JWTError, jwt from app.api.deps import get_current_user -from app.config.settings import settings -from app.core.orchestrator import orchestrate, orchestrate_stream +from app.core.orchestrator import orchestrate from app.schemas import ChatRequest, UserProfile router = APIRouter(prefix="/chat", tags=["chat"]) -_HEARTBEAT_INTERVAL = 30 # seconds - @router.post("") async def chat( @@ -31,48 +27,3 @@ async def chat( """ result = await orchestrate(body) return JSONResponse(content=result.model_dump()) - - -@router.websocket("/stream") -async def chat_stream(websocket: WebSocket) -> None: - """Streaming chat via WebSocket. - - Auth: ``?token=`` query param (Bearer not possible during WS handshake). - - Protocol: - 1. Client sends ``ChatRequest`` as the first JSON text frame. - 2. Server streams response text chunks. - 3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``. - 4. Server pings every 30 s to keep the connection alive. - """ - # Authenticate before accepting the connection - token = websocket.query_params.get("token", "") - try: - payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]) - user_id: str | None = payload.get("sub") - if not user_id: - raise JWTError("missing sub") - except JWTError: - await websocket.close(code=1008) # 1008 = Policy Violation - return - - await websocket.accept() - - try: - raw = await websocket.receive_text() - body = ChatRequest.model_validate_json(raw) - - async def _heartbeat() -> None: - while True: - await asyncio.sleep(_HEARTBEAT_INTERVAL) - await websocket.send_text(json.dumps({"ping": True})) - - heartbeat_task = asyncio.create_task(_heartbeat()) - try: - async for chunk in orchestrate_stream(body): - await websocket.send_text(chunk) - finally: - heartbeat_task.cancel() - - except WebSocketDisconnect: - pass diff --git a/app/api/routes/device_ws.py b/app/api/routes/device_ws.py index 2e0c038..0b3e4ad 100644 --- a/app/api/routes/device_ws.py +++ b/app/api/routes/device_ws.py @@ -33,14 +33,18 @@ from __future__ import annotations import asyncio import json import logging +from uuid import uuid4 from fastapi import APIRouter, WebSocket, WebSocketDisconnect from jose import JWTError, jwt -from sqlalchemy import select, update +from sqlalchemy import update from app.config.settings import settings from app.core.agent_runner import trigger_pending_runs from app.core.device_manager import device_manager +from app.core.orchestrator import orchestrate_v3_stream +from app.core.output_formatter import HomeFormatter, PopupFormatter +from app.core.ws_context import clear_client_executor, set_client_executor from app.db import async_session from app.models import AgentRunLog from app.schemas import WsFrameType @@ -173,6 +177,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None: "device_ws: agent_complete missing run_id from user=%s", user_id ) + elif frame_type == WsFrameType.home_request: + asyncio.create_task( + _handle_home_request(websocket, user_id, frame) + ) + + elif frame_type == WsFrameType.popup_request: + asyncio.create_task( + _handle_popup_request(websocket, user_id, frame) + ) + elif frame_type == "pong": # Heartbeat ack — nothing to do, connection is alive. pass @@ -183,6 +197,76 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None: ) +# ── v3 Chat Handlers ────────────────────────────────────────────────── + +async def _make_ws_executor(websocket: WebSocket, user_id: str): + """Return a callback that sends tool_call frames and awaits tool_result.""" + async def _executor(payload: dict) -> dict: + payload["type"] = WsFrameType.tool_call + await websocket.send_text(json.dumps(payload)) + future = device_manager.create_pending_call(user_id, payload["id"]) + return await future + return _executor + + +async def _handle_home_request( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Handle a home_request frame — streams HomeFormatter output back on the socket.""" + request_id = frame.get("request_id") or str(uuid4()) + message: str = frame.get("message", "") + context: dict = { + "conversation_history": frame.get("conversation_history", []), + } + + executor = await _make_ws_executor(websocket, user_id) + set_client_executor(executor) + try: + token_stream = orchestrate_v3_stream(user_id, message, context) + # Collect tool_results via the formatter after the stream completes. + # We pass an empty list initially; tool_results are populated during + # the agent run via ws_context._tool_result_collector (set inside _tool_loop_stream). + formatter = HomeFormatter(request_id=request_id, tool_results=[]) + async for ws_frame in formatter.format(token_stream): + await websocket.send_text(ws_frame.model_dump_json()) + except Exception as exc: + logger.error( + "device_ws: home_request failed user=%s req=%s: %s", + user_id, request_id, exc, + ) + finally: + clear_client_executor() + + +async def _handle_popup_request( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Handle a popup_request frame — streams PopupFormatter output back on the socket.""" + request_id = frame.get("request_id") or str(uuid4()) + message: str = frame.get("message", "") + scope: dict = frame.get("scope", {}) + context: dict = {"scope": scope} + + executor = await _make_ws_executor(websocket, user_id) + set_client_executor(executor) + try: + token_stream = orchestrate_v3_stream(user_id, message, context) + formatter = PopupFormatter(request_id=request_id) + async for ws_frame in formatter.format(token_stream): + await websocket.send_text(ws_frame.model_dump_json()) + except Exception as exc: + logger.error( + "device_ws: popup_request failed user=%s req=%s: %s", + user_id, request_id, exc, + ) + finally: + clear_client_executor() + + # ── Heartbeat ───────────────────────────────────────────────────────── async def _heartbeat_loop(websocket: WebSocket) -> None: diff --git a/tests/test_ws_unified.py b/tests/test_ws_unified.py new file mode 100644 index 0000000..7eb7337 --- /dev/null +++ b/tests/test_ws_unified.py @@ -0,0 +1,157 @@ +"""Integration tests for the unified WebSocket handler (Step 5). + +Tests the device WS endpoint with home_request and popup_request frames, +verifying that the correct v3 frame sequence is returned. + +LLM calls are mocked to avoid network dependency. +""" + +from __future__ import annotations + +import json +from unittest.mock import patch + +import pytest + +from app.db import get_session +from app.main import app +from app.schemas import WsFrameType +from tests.conftest import TEST_USER_IDS, make_jwt + +USER_ID = TEST_USER_IDS["power"] + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +@pytest.fixture(autouse=True) +def _override_db(db_session): + async def _gen(): + yield db_session + + app.dependency_overrides[get_session] = _gen + yield + app.dependency_overrides.pop(get_session, None) + + +def _recv_until_end(ws, max_frames: int = 20) -> list[dict]: + """Receive frames until stream_end (or stream_end inside popup flow), or max_frames.""" + frames = [] + for _ in range(max_frames): + raw = ws.receive_text() + frame = json.loads(raw) + frames.append(frame) + if frame.get("type") == WsFrameType.stream_end: + break + return frames + + +async def _mock_home_stream(user_id, message, context, reg=None): + yield "task_agent", "" + yield "task_agent", '{"type": "text", "content": "Hello"}' + + +async def _mock_popup_stream(user_id, message, context, reg=None): + yield "task_agent", "" + yield "task_agent", "Here is a summary" + + +# ── tests ───────────────────────────────────────────────────────────────────── + +def test_home_request_produces_stream_frames(client): + """home_request → stream_start, stream_text+, stream_end.""" + token = make_jwt("power", user_id=USER_ID) + + with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_home_stream): + with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: + ws.send_text(json.dumps({ + "type": "device_hello", "device_id": "dev-1", "agent_ids": [] + })) + ws.send_text(json.dumps({ + "type": "home_request", + "request_id": "r1", + "message": "List my tasks", + "conversation_history": [], + })) + frames = _recv_until_end(ws) + + types = [f["type"] for f in frames] + assert WsFrameType.stream_start in types + assert WsFrameType.stream_end in types + assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end) + + +def test_popup_request_produces_domain_frame(client): + """popup_request → popup_domain first, then stream_text*, stream_end.""" + token = make_jwt("power", user_id=USER_ID) + + with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_popup_stream): + with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: + ws.send_text(json.dumps({ + "type": "device_hello", "device_id": "dev-2", "agent_ids": [] + })) + ws.send_text(json.dumps({ + "type": "popup_request", + "request_id": "p1", + "message": "Summarize this task", + "scope": {"type": "task", "id": "task-123"}, + })) + frames = _recv_until_end(ws) + + types = [f["type"] for f in frames] + assert WsFrameType.popup_domain in types + assert WsFrameType.stream_end in types + assert types.index(WsFrameType.popup_domain) < types.index(WsFrameType.stream_end) + + domain_frame = next(f for f in frames if f["type"] == WsFrameType.popup_domain) + assert domain_frame["domain"] == "tasks" + assert domain_frame["request_id"] == "p1" + + +def test_home_request_request_id_propagated(client): + """request_id in home_request is echoed in all response frames.""" + token = make_jwt("power", user_id=USER_ID) + req_id = "my-unique-req-id" + + async def _stream(user_id, message, context, reg=None): + yield "note_agent", "" + yield "note_agent", '{"type": "text", "content": "ok"}' + + with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_stream): + with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: + ws.send_text(json.dumps({ + "type": "device_hello", "device_id": "dev-3", "agent_ids": [] + })) + ws.send_text(json.dumps({ + "type": "home_request", + "request_id": req_id, + "message": "hello", + })) + frames = _recv_until_end(ws) + + for f in frames: + if "request_id" in f: + assert f["request_id"] == req_id + + +def test_tool_result_dispatch_silent_on_unknown_id(client): + """tool_result for unknown call_id is silently ignored — no crash.""" + token = make_jwt("power", user_id=USER_ID) + + with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05): + with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: + ws.send_text(json.dumps({ + "type": "device_hello", "device_id": "dev-4", "agent_ids": [] + })) + ws.send_text(json.dumps({ + "type": "tool_result", "id": "no-such-id", "ok": True + })) + # If connection is still alive, we'll get the heartbeat ping + msg = json.loads(ws.receive_text()) + assert msg["type"] == "ping" + + +def test_invalid_jwt_rejected(client): + """Connection with bad token is closed before or after accept.""" + with pytest.raises(Exception): + with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws: + ws.receive_text()