Files
api/tests/test_ws_unified.py
roberto 617a17db40 feat: HomeFormatter parses inline entity tags instead of tool_end blocks
The supervisor LLM now embeds <type>[id1,id2]</type> entity tags in its
response text. The HomeFormatter buffers streamed tokens, detects complete
tags across chunk boundaries, and emits WsStreamBlock with entity type +
specific IDs. This replaces the old approach of emitting blocks for every
tool_end event, which dumped ALL entities regardless of relevance.

Also fixes:
- NoneType guard on metadata in _run_graph_stream (metadata can be None)
- Updated _HOME_SYSTEM prompt with entity tag instructions
- Updated all affected tests
2026-03-12 00:01:06 +01:00

159 lines
5.9 KiB
Python

"""Integration tests for the unified WebSocket handler (Step 5).
Tests the device WS endpoint with home_request and floating_request frames,
verifying that the correct v3 frame sequence is returned.
LLM calls are mocked to avoid network dependency.
"""
from __future__ import annotations
import json
from unittest.mock import patch
import pytest
from app.db import get_session
from app.main import app
from app.schemas import WsFrameType
from tests.conftest import TEST_USER_IDS, make_jwt
# JWTs in these tests are minted for the "power" test user; its id is taken
# from the shared TEST_USER_IDS table exported by tests.conftest.
USER_ID = TEST_USER_IDS["power"]
# ── helpers ───────────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's ``get_session`` dependency at this test's ``db_session``.

    Applied automatically to every test in the module. The override is
    installed before the test body runs and removed again on teardown.
    """

    async def _session_override():
        # Hands the shared test session to any route depending on get_session.
        yield db_session

    app.dependency_overrides[get_session] = _session_override
    yield  # test body runs here
    app.dependency_overrides.pop(get_session, None)
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
    """Collect JSON-decoded frames from *ws* up to and including the first
    ``stream_end`` frame.

    Reading also stops after ``max_frames`` frames even if no ``stream_end``
    has arrived. Callers that need the stream to have finished must check for
    ``stream_end`` in the result themselves.
    """
    received: list[dict] = []
    while len(received) < max_frames:
        frame = json.loads(ws.receive_text())
        received.append(frame)
        if frame.get("type") == WsFrameType.stream_end:
            break
    return received
async def _mock_home_stream(user_id, message, context, db_session_factory=None):
yield "token", "Here are your tasks:\n<task>[t1,t2]</task>"
yield "mutations", []
async def _mock_floating_stream(user_id, message, context, scope=None, db_session_factory=None):
yield "tool_end", {"name": "task_agent", "result": "ok"}
yield "token", "Here is a summary"
yield "mutations", []
# ── tests ─────────────────────────────────────────────────────────────────────
def test_home_request_produces_stream_frames(client):
    """A home_request produces a stream that contains stream_start followed
    later by stream_end (intermediate stream_text frames are not inspected)."""
    jwt = make_jwt("power", user_id=USER_ID)
    hello = {"type": "device_hello", "device_id": "dev-1", "agent_ids": []}
    request = {
        "type": "home_request",
        "request_id": "r1",
        "message": "List my tasks",
        "conversation_history": [],
    }

    with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream):
        with client.websocket_connect(f"/api/v1/ws/device?token={jwt}") as ws:
            ws.send_text(json.dumps(hello))
            ws.send_text(json.dumps(request))
            frames = _recv_until_end(ws)

    kinds = [frame["type"] for frame in frames]
    assert WsFrameType.stream_start in kinds
    assert WsFrameType.stream_end in kinds
    assert kinds.index(WsFrameType.stream_start) < kinds.index(WsFrameType.stream_end)
def test_floating_request_produces_domain_frame(client):
    """A task-scoped floating_request emits a floating_domain frame before
    stream_end. That frame carries domain "tasks" and echoes request_id "p1"."""
    jwt = make_jwt("power", user_id=USER_ID)
    hello = {"type": "device_hello", "device_id": "dev-2", "agent_ids": []}
    request = {
        "type": "floating_request",
        "request_id": "p1",
        "message": "Summarize this task",
        "scope": {"type": "task", "id": "task-123"},
    }

    with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
        with client.websocket_connect(f"/api/v1/ws/device?token={jwt}") as ws:
            ws.send_text(json.dumps(hello))
            ws.send_text(json.dumps(request))
            frames = _recv_until_end(ws)

    kinds = [frame["type"] for frame in frames]
    assert WsFrameType.floating_domain in kinds
    assert WsFrameType.stream_end in kinds
    assert kinds.index(WsFrameType.floating_domain) < kinds.index(WsFrameType.stream_end)

    domain_frame = next(
        frame for frame in frames if frame["type"] == WsFrameType.floating_domain
    )
    assert domain_frame["domain"] == "tasks"
    assert domain_frame["request_id"] == "p1"
def test_home_request_request_id_propagated(client):
    """Every response frame that carries a request_id echoes the one sent in
    the home_request.

    The test also guards against a vacuous pass. It requires that:
    - the stream actually reached stream_end;
    - at least one frame carried a request_id.

    Without these checks, a truncated stream, or one where no frame is
    tagged, would satisfy the per-frame check trivially.
    """
    token = make_jwt("power", user_id=USER_ID)
    req_id = "my-unique-req-id"

    async def _stream(user_id, message, context, db_session_factory=None):
        # Minimal home stream: one plain token (no entity tags), no mutations.
        yield "token", "ok"
        yield "mutations", []

    with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-3", "agent_ids": []
            }))
            # Unlike the first home test, no conversation_history is sent here.
            ws.send_text(json.dumps({
                "type": "home_request",
                "request_id": req_id,
                "message": "hello",
            }))
            frames = _recv_until_end(ws)

    # _recv_until_end stops silently after max_frames; make a missing
    # stream_end an explicit failure instead of a truncated "success".
    assert WsFrameType.stream_end in [f["type"] for f in frames]

    echoed = [f for f in frames if "request_id" in f]
    assert echoed, "no response frame carried a request_id"
    for f in echoed:
        assert f["request_id"] == req_id
def test_tool_result_dispatch_silent_on_unknown_id(client):
    """A tool_result whose id matches no pending call must be ignored without
    dropping the connection. The next heartbeat ping proves the socket is
    still alive."""
    jwt = make_jwt("power", user_id=USER_ID)
    hello = {"type": "device_hello", "device_id": "dev-4", "agent_ids": []}
    stray_result = {"type": "tool_result", "id": "no-such-id", "ok": True}

    # Shrink the heartbeat interval (0.05 — presumably seconds) so a ping
    # follows quickly after the stray frame.
    with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05):
        with client.websocket_connect(f"/api/v1/ws/device?token={jwt}") as ws:
            for payload in (hello, stray_result):
                ws.send_text(json.dumps(payload))
            # If the connection survived, the server's heartbeat ping arrives next.
            first_reply = json.loads(ws.receive_text())

    assert first_reply["type"] == "ping"
def test_invalid_jwt_rejected(client):
    """A malformed token must make the connection fail. The server may refuse
    the handshake or close the socket right after accepting it; either counts."""
    bad_url = "/api/v1/ws/device?token=badtoken"
    # NOTE(review): Exception is very broad; Starlette's WebSocketDisconnect
    # would pin the expected failure mode more precisely.
    with pytest.raises(Exception):
        with client.websocket_connect(bad_url) as ws:
            ws.receive_text()