step-5: unify ws handler (device_ws.py, chat.py)
- device_ws.py: dispatch home_request/popup_request to HomeFormatter/PopupFormatter via async tasks; each request gets a UUID request_id for frame correlation - chat.py: remove chat_stream WS endpoint (superseded by unified device WS); keep POST /chat REST fallback unchanged - 5 new integration tests pass; all 22 existing device_ws tests still pass Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -253,7 +253,7 @@ pytest tests/test_ws_unified.py
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Status**:
|
**Status**:
|
||||||
- [ ] Step 5 complete
|
- [x] Step 5 complete
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
**Commit**: After tests pass, commit with:
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -1,23 +1,19 @@
|
|||||||
"""Chat routes: POST /chat and WebSocket /chat/stream."""
|
"""Chat routes: POST /chat (REST fallback).
|
||||||
|
|
||||||
|
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
from fastapi import APIRouter, Depends
|
||||||
import json
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from jose import JWTError, jwt
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.config.settings import settings
|
from app.core.orchestrator import orchestrate
|
||||||
from app.core.orchestrator import orchestrate, orchestrate_stream
|
|
||||||
from app.schemas import ChatRequest, UserProfile
|
from app.schemas import ChatRequest, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def chat(
|
async def chat(
|
||||||
@@ -31,48 +27,3 @@ async def chat(
|
|||||||
"""
|
"""
|
||||||
result = await orchestrate(body)
|
result = await orchestrate(body)
|
||||||
return JSONResponse(content=result.model_dump())
|
return JSONResponse(content=result.model_dump())
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/stream")
|
|
||||||
async def chat_stream(websocket: WebSocket) -> None:
|
|
||||||
"""Streaming chat via WebSocket.
|
|
||||||
|
|
||||||
Auth: ``?token=<jwt>`` query param (Bearer not possible during WS handshake).
|
|
||||||
|
|
||||||
Protocol:
|
|
||||||
1. Client sends ``ChatRequest`` as the first JSON text frame.
|
|
||||||
2. Server streams response text chunks.
|
|
||||||
3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``.
|
|
||||||
4. Server pings every 30 s to keep the connection alive.
|
|
||||||
"""
|
|
||||||
# Authenticate before accepting the connection
|
|
||||||
token = websocket.query_params.get("token", "")
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
|
|
||||||
user_id: str | None = payload.get("sub")
|
|
||||||
if not user_id:
|
|
||||||
raise JWTError("missing sub")
|
|
||||||
except JWTError:
|
|
||||||
await websocket.close(code=1008) # 1008 = Policy Violation
|
|
||||||
return
|
|
||||||
|
|
||||||
await websocket.accept()
|
|
||||||
|
|
||||||
try:
|
|
||||||
raw = await websocket.receive_text()
|
|
||||||
body = ChatRequest.model_validate_json(raw)
|
|
||||||
|
|
||||||
async def _heartbeat() -> None:
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
|
||||||
await websocket.send_text(json.dumps({"ping": True}))
|
|
||||||
|
|
||||||
heartbeat_task = asyncio.create_task(_heartbeat())
|
|
||||||
try:
|
|
||||||
async for chunk in orchestrate_stream(body):
|
|
||||||
await websocket.send_text(chunk)
|
|
||||||
finally:
|
|
||||||
heartbeat_task.cancel()
|
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -33,14 +33,18 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from sqlalchemy import select, update
|
from sqlalchemy import update
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
|
from app.core.orchestrator import orchestrate_v3_stream
|
||||||
|
from app.core.output_formatter import HomeFormatter, PopupFormatter
|
||||||
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
from app.schemas import WsFrameType
|
from app.schemas import WsFrameType
|
||||||
@@ -173,6 +177,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
"device_ws: agent_complete missing run_id from user=%s", user_id
|
"device_ws: agent_complete missing run_id from user=%s", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.home_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_home_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.popup_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_popup_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -183,6 +197,76 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||||
|
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||||
|
async def _executor(payload: dict) -> dict:
|
||||||
|
payload["type"] = WsFrameType.tool_call
|
||||||
|
await websocket.send_text(json.dumps(payload))
|
||||||
|
future = device_manager.create_pending_call(user_id, payload["id"])
|
||||||
|
return await future
|
||||||
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_home_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
context: dict = {
|
||||||
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
token_stream = orchestrate_v3_stream(user_id, message, context)
|
||||||
|
# Collect tool_results via the formatter after the stream completes.
|
||||||
|
# We pass an empty list initially; tool_results are populated during
|
||||||
|
# the agent run via ws_context._tool_result_collector (set inside _tool_loop_stream).
|
||||||
|
formatter = HomeFormatter(request_id=request_id, tool_results=[])
|
||||||
|
async for ws_frame in formatter.format(token_stream):
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: home_request failed user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_popup_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a popup_request frame — streams PopupFormatter output back on the socket."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
scope: dict = frame.get("scope", {})
|
||||||
|
context: dict = {"scope": scope}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
token_stream = orchestrate_v3_stream(user_id, message, context)
|
||||||
|
formatter = PopupFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(token_stream):
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: popup_request failed user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||||
|
|||||||
157
tests/test_ws_unified.py
Normal file
157
tests/test_ws_unified.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""Integration tests for the unified WebSocket handler (Step 5).
|
||||||
|
|
||||||
|
Tests the device WS endpoint with home_request and popup_request frames,
|
||||||
|
verifying that the correct v3 frame sequence is returned.
|
||||||
|
|
||||||
|
LLM calls are mocked to avoid network dependency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.schemas import WsFrameType
|
||||||
|
from tests.conftest import TEST_USER_IDS, make_jwt
|
||||||
|
|
||||||
|
USER_ID = TEST_USER_IDS["power"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
||||||
|
"""Receive frames until stream_end (or stream_end inside popup flow), or max_frames."""
|
||||||
|
frames = []
|
||||||
|
for _ in range(max_frames):
|
||||||
|
raw = ws.receive_text()
|
||||||
|
frame = json.loads(raw)
|
||||||
|
frames.append(frame)
|
||||||
|
if frame.get("type") == WsFrameType.stream_end:
|
||||||
|
break
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
async def _mock_home_stream(user_id, message, context, reg=None):
|
||||||
|
yield "task_agent", ""
|
||||||
|
yield "task_agent", '{"type": "text", "content": "Hello"}'
|
||||||
|
|
||||||
|
|
||||||
|
async def _mock_popup_stream(user_id, message, context, reg=None):
|
||||||
|
yield "task_agent", ""
|
||||||
|
yield "task_agent", "Here is a summary"
|
||||||
|
|
||||||
|
|
||||||
|
# ── tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_home_request_produces_stream_frames(client):
|
||||||
|
"""home_request → stream_start, stream_text+, stream_end."""
|
||||||
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
|
||||||
|
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_home_stream):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "device_hello", "device_id": "dev-1", "agent_ids": []
|
||||||
|
}))
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "home_request",
|
||||||
|
"request_id": "r1",
|
||||||
|
"message": "List my tasks",
|
||||||
|
"conversation_history": [],
|
||||||
|
}))
|
||||||
|
frames = _recv_until_end(ws)
|
||||||
|
|
||||||
|
types = [f["type"] for f in frames]
|
||||||
|
assert WsFrameType.stream_start in types
|
||||||
|
assert WsFrameType.stream_end in types
|
||||||
|
assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)
|
||||||
|
|
||||||
|
|
||||||
|
def test_popup_request_produces_domain_frame(client):
|
||||||
|
"""popup_request → popup_domain first, then stream_text*, stream_end."""
|
||||||
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
|
||||||
|
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_popup_stream):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
|
||||||
|
}))
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "popup_request",
|
||||||
|
"request_id": "p1",
|
||||||
|
"message": "Summarize this task",
|
||||||
|
"scope": {"type": "task", "id": "task-123"},
|
||||||
|
}))
|
||||||
|
frames = _recv_until_end(ws)
|
||||||
|
|
||||||
|
types = [f["type"] for f in frames]
|
||||||
|
assert WsFrameType.popup_domain in types
|
||||||
|
assert WsFrameType.stream_end in types
|
||||||
|
assert types.index(WsFrameType.popup_domain) < types.index(WsFrameType.stream_end)
|
||||||
|
|
||||||
|
domain_frame = next(f for f in frames if f["type"] == WsFrameType.popup_domain)
|
||||||
|
assert domain_frame["domain"] == "tasks"
|
||||||
|
assert domain_frame["request_id"] == "p1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_home_request_request_id_propagated(client):
|
||||||
|
"""request_id in home_request is echoed in all response frames."""
|
||||||
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
req_id = "my-unique-req-id"
|
||||||
|
|
||||||
|
async def _stream(user_id, message, context, reg=None):
|
||||||
|
yield "note_agent", ""
|
||||||
|
yield "note_agent", '{"type": "text", "content": "ok"}'
|
||||||
|
|
||||||
|
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_stream):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "device_hello", "device_id": "dev-3", "agent_ids": []
|
||||||
|
}))
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "home_request",
|
||||||
|
"request_id": req_id,
|
||||||
|
"message": "hello",
|
||||||
|
}))
|
||||||
|
frames = _recv_until_end(ws)
|
||||||
|
|
||||||
|
for f in frames:
|
||||||
|
if "request_id" in f:
|
||||||
|
assert f["request_id"] == req_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_result_dispatch_silent_on_unknown_id(client):
|
||||||
|
"""tool_result for unknown call_id is silently ignored — no crash."""
|
||||||
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
|
||||||
|
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "device_hello", "device_id": "dev-4", "agent_ids": []
|
||||||
|
}))
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "tool_result", "id": "no-such-id", "ok": True
|
||||||
|
}))
|
||||||
|
# If connection is still alive, we'll get the heartbeat ping
|
||||||
|
msg = json.loads(ws.receive_text())
|
||||||
|
assert msg["type"] == "ping"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_jwt_rejected(client):
|
||||||
|
"""Connection with bad token is closed before or after accept."""
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws:
|
||||||
|
ws.receive_text()
|
||||||
Reference in New Issue
Block a user