rename popup chat to floating chat

This commit is contained in:
2026-03-08 22:53:31 +01:00
parent 0bd46937d3
commit 34f01234c9
8 changed files with 102 additions and 102 deletions

View File

@@ -36,18 +36,18 @@ This keeps the codebase clean and prevents confusion. When removing code, note i
**Changes**: **Changes**:
- `app/schemas.py` — Add to `WsFrameType` enum: - `app/schemas.py` — Add to `WsFrameType` enum:
- `home_request`, `popup_request` - `home_request`, `floating_request`
- `stream_start`, `stream_text`, `stream_block`, `stream_end` - `stream_start`, `stream_text`, `stream_block`, `stream_end`
- `popup_domain` - `floating_domain`
- `data_request`, `data_response`, `mutation` - `data_request`, `data_response`, `mutation`
- Add Pydantic models: - Add Pydantic models:
- `WsHomeRequest(type, message, conversation_history?)` - `WsHomeRequest(type, message, conversation_history?)`
- `WsPopupRequest(type, message, scope: {type, id?})` - `WsFloatingRequest(type, message, scope: {type, id?})`
- `WsStreamStart(type, request_id)` - `WsStreamStart(type, request_id)`
- `WsStreamText(type, request_id, chunk)` - `WsStreamText(type, request_id, chunk)`
- `WsStreamBlock(type, request_id, block_type, data)` - `WsStreamBlock(type, request_id, block_type, data)`
- `WsStreamEnd(type, request_id, mutations?)` - `WsStreamEnd(type, request_id, mutations?)`
- `WsPopupDomain(type, request_id, domain)` - `WsFloatingDomain(type, request_id, domain)`
- Keep all existing frame types (backward compat). - Keep all existing frame types (backward compat).
**Files touched**: `app/schemas.py` **Files touched**: `app/schemas.py`
@@ -130,7 +130,7 @@ git commit -m "step-3: add router refactor with streaming support (orchestrator.
## Step 4 — Output Formatting Layer (NEW: output_formatter.py) ## Step 4 — Output Formatting Layer (NEW: output_formatter.py)
**Goal**: Home and Popup responses diverge at this layer only. **Goal**: Home and Floating responses diverge at this layer only.
### Block Types (from Electron app components) ### Block Types (from Electron app components)
@@ -194,14 +194,14 @@ Supported entity types (matching Electron component types):
- `table` -> buffers, validates headers/rows structure, yields `WsStreamBlock` - `table` -> buffers, validates headers/rows structure, yields `WsStreamBlock`
- `timeline` -> buffers, validates checkpoint objects, yields `WsStreamBlock` - `timeline` -> buffers, validates checkpoint objects, yields `WsStreamBlock`
- Invalid blocks are logged and skipped (never crash the stream) - Invalid blocks are logged and skipped (never crash the stream)
- `PopupFormatter`: - `FloatingFormatter`:
- Receives `agent_name` from orchestrator - Receives `agent_name` from orchestrator
- Maps agent name to domain (deterministic, by code — no LLM): - Maps agent name to domain (deterministic, by code — no LLM):
- `task_agent` -> `"tasks"` - `task_agent` -> `"tasks"`
- `checkpoint_agent` -> `"checkpoints"` - `checkpoint_agent` -> `"checkpoints"`
- `note_agent` -> `"notes"` - `note_agent` -> `"notes"`
- `project_agent` -> `"projects"` - `project_agent` -> `"projects"`
- Yields `WsPopupDomain` immediately - Yields `WsFloatingDomain` immediately
- Then yields `WsStreamText` for all tokens (text-only, no blocks) - Then yields `WsStreamText` for all tokens (text-only, no blocks)
**Files touched**: `app/core/output_formatter.py` (new) **Files touched**: `app/core/output_formatter.py` (new)
@@ -223,13 +223,13 @@ git commit -m "step-4: add output formatting layer (output_formatter.py)"
## Step 5 — Unified WS Handler (device_ws.py, chat.py, main.py) ## Step 5 — Unified WS Handler (device_ws.py, chat.py, main.py)
**Goal**: Single multiplexed WebSocket handles device frames + Home/Popup chat. **Goal**: Single multiplexed WebSocket handles device frames + Home/Floating chat.
**Changes**: **Changes**:
- `app/api/routes/device_ws.py`: - `app/api/routes/device_ws.py`:
- Extend `_message_loop` dispatch to handle `home_request` and `popup_request`: - Extend `_message_loop` dispatch to handle `home_request` and `floating_request`:
- On `home_request`: set `ws_context` executor, call `orchestrate_v3_stream`, pipe through `HomeFormatter`, send frames back on same socket. - On `home_request`: set `ws_context` executor, call `orchestrate_v3_stream`, pipe through `HomeFormatter`, send frames back on same socket.
- On `popup_request`: same, but pipe through `PopupFormatter`. - On `floating_request`: same, but pipe through `FloatingFormatter`.
- Wrap both in try/finally to clear `ws_context`. - Wrap both in try/finally to clear `ws_context`.
- Each request gets a `request_id` (UUID) for frame correlation. - Each request gets a `request_id` (UUID) for frame correlation.
- Concurrent requests from same client are supported (each runs as an async task). - Concurrent requests from same client are supported (each runs as an async task).
@@ -246,7 +246,7 @@ git commit -m "step-4: add output formatting layer (output_formatter.py)"
1. Connects to `/api/v1/ws/device` 1. Connects to `/api/v1/ws/device`
2. Sends `device_hello` 2. Sends `device_hello`
3. Sends `home_request` -> receives `stream_start`, `stream_text`*, `stream_end` 3. Sends `home_request` -> receives `stream_start`, `stream_text`*, `stream_end`
4. Sends `popup_request` -> receives `popup_domain`, `stream_text`*, `stream_end` 4. Sends `floating_request` -> receives `floating_domain`, `stream_text`*, `stream_end`
5. Verifies `tool_call`/`tool_result` round-trip still works during chat 5. Verifies `tool_call`/`tool_result` round-trip still works during chat
``` ```
pytest tests/test_ws_unified.py pytest tests/test_ws_unified.py
@@ -313,7 +313,7 @@ git commit -m "step-6: add memory models and migration (models.py, alembic)"
3. Embed interaction, encrypt and upsert in `MemoryAssociative` 3. Embed interaction, encrypt and upsert in `MemoryAssociative`
- `update_core(user_id, key, value)` — explicit preference update - `update_core(user_id, key, value)` — explicit preference update
- All read/write operations encrypt/decrypt using the user's Fernet key from `User.encryption_key` - All read/write operations encrypt/decrypt using the user's Fernet key from `User.encryption_key`
- `app/api/routes/device_ws.py` — Update `home_request` and `popup_request` handlers: - `app/api/routes/device_ws.py` — Update `home_request` and `floating_request` handlers:
- Before orchestrator: `enriched = await memory.enrich_context(user_id, message)` - Before orchestrator: `enriched = await memory.enrich_context(user_id, message)`
- After response complete: `await memory.store_episode(user_id, ...)` - After response complete: `await memory.store_episode(user_id, ...)`

View File

@@ -44,7 +44,7 @@ from app.core.agent_runner import trigger_pending_runs
from app.core.device_manager import device_manager from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware from app.core.memory_middleware import MemoryMiddleware
from app.core.orchestrator import orchestrate_v3_stream from app.core.orchestrator import orchestrate_v3_stream
from app.core.output_formatter import HomeFormatter, PopupFormatter from app.core.output_formatter import HomeFormatter, FloatingFormatter
from app.core.ws_context import clear_client_executor, set_client_executor from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session from app.db import async_session
from app.models import AgentRunLog from app.models import AgentRunLog
@@ -183,9 +183,9 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_home_request(websocket, user_id, frame) _handle_home_request(websocket, user_id, frame)
) )
elif frame_type == WsFrameType.popup_request: elif frame_type == WsFrameType.floating_request:
asyncio.create_task( asyncio.create_task(
_handle_popup_request(websocket, user_id, frame) _handle_floating_request(websocket, user_id, frame)
) )
elif frame_type == "pong": elif frame_type == "pong":
@@ -257,12 +257,12 @@ async def _handle_home_request(
) )
async def _handle_popup_request( async def _handle_floating_request(
websocket: WebSocket, websocket: WebSocket,
user_id: str, user_id: str,
frame: dict, frame: dict,
) -> None: ) -> None:
"""Handle a popup_request frame — streams PopupFormatter output back on the socket.""" """Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
request_id = frame.get("request_id") or str(uuid4()) request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "") message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4()) session_id: str = frame.get("session_id") or str(uuid4())
@@ -280,14 +280,14 @@ async def _handle_popup_request(
response_chunks: list[str] = [] response_chunks: list[str] = []
try: try:
token_stream = orchestrate_v3_stream(user_id, message, context) token_stream = orchestrate_v3_stream(user_id, message, context)
formatter = PopupFormatter(request_id=request_id) formatter = FloatingFormatter(request_id=request_id)
async for ws_frame in formatter.format(token_stream): async for ws_frame in formatter.format(token_stream):
await websocket.send_text(ws_frame.model_dump_json()) await websocket.send_text(ws_frame.model_dump_json())
if ws_frame.type == "stream_text": # type: ignore[union-attr] if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc: except Exception as exc:
logger.error( logger.error(
"device_ws: popup_request failed user=%s req=%s: %s", "device_ws: floating_request failed user=%s req=%s: %s",
user_id, request_id, exc, user_id, request_id, exc,
) )
finally: finally:

View File

@@ -166,7 +166,7 @@ async def orchestrate_v3_stream(
"""v3 streaming orchestration — yields (agent_name, token) pairs. """v3 streaming orchestration — yields (agent_name, token) pairs.
The first yield always carries the agent_name with an empty token so that The first yield always carries the agent_name with an empty token so that
callers (e.g. PopupFormatter) can detect the routing domain before any text callers (e.g. FloatingFormatter) can detect the routing domain before any text
tokens arrive. tokens arrive.
""" """
if reg is None: if reg is None:

View File

@@ -1,7 +1,7 @@
"""Output Formatter — transforms orchestrator token streams into WS frame sequences. """Output Formatter — transforms orchestrator token streams into WS frame sequences.
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
PopupFormatter: produces popup_domain, stream_text, stream_end FloatingFormatter: produces floating_domain, stream_text, stream_end
""" """
from __future__ import annotations from __future__ import annotations
@@ -12,7 +12,7 @@ from collections.abc import AsyncGenerator
from typing import Any from typing import Any
from app.schemas import ( from app.schemas import (
WsPopupDomain, WsFloatingDomain,
WsStreamBlock, WsStreamBlock,
WsStreamEnd, WsStreamEnd,
WsStreamStart, WsStreamStart,
@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron) # Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"} _VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}
# Map agent name → popup domain # Map agent name → floating domain
_AGENT_DOMAIN: dict[str, str] = { _AGENT_DOMAIN: dict[str, str] = {
"task_agent": "tasks", "task_agent": "tasks",
"checkpoint_agent": "checkpoints", "checkpoint_agent": "checkpoints",
@@ -32,7 +32,7 @@ _AGENT_DOMAIN: dict[str, str] = {
"project_agent": "projects", "project_agent": "projects",
} }
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsPopupDomain WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
class HomeFormatter: class HomeFormatter:
@@ -191,11 +191,11 @@ class HomeFormatter:
return matches if matches else None return matches if matches else None
class PopupFormatter: class FloatingFormatter:
"""Parses a token stream from orchestrate_v3_stream and yields WS frames. """Parses a token stream from orchestrate_v3_stream and yields WS frames.
Emits popup_domain immediately (from agent_name), then streams all tokens Emits floating_domain immediately (from agent_name), then streams all tokens
as plain stream_text — no block parsing for popup context. as plain stream_text — no block parsing for floating context.
""" """
def __init__(self, request_id: str) -> None: def __init__(self, request_id: str) -> None:
@@ -210,7 +210,7 @@ class PopupFormatter:
async for agent_name, token in token_stream: async for agent_name, token in token_stream:
if not domain_sent: if not domain_sent:
domain = _AGENT_DOMAIN.get(agent_name, "tasks") domain = _AGENT_DOMAIN.get(agent_name, "tasks")
yield WsPopupDomain( yield WsFloatingDomain(
request_id=self.request_id, request_id=self.request_id,
domain=domain, # type: ignore[arg-type] domain=domain, # type: ignore[arg-type]
) )

View File

@@ -174,12 +174,12 @@ class WsFrameType(str, Enum):
device_hello = "device_hello" device_hello = "device_hello"
# ── v3 frame types ───────────────────────────────────────────────── # ── v3 frame types ─────────────────────────────────────────────────
home_request = "home_request" home_request = "home_request"
popup_request = "popup_request" floating_request = "floating_request"
stream_start = "stream_start" stream_start = "stream_start"
stream_text = "stream_text" stream_text = "stream_text"
stream_block = "stream_block" stream_block = "stream_block"
stream_end = "stream_end" stream_end = "stream_end"
popup_domain = "popup_domain" floating_domain = "floating_domain"
data_request = "data_request" data_request = "data_request"
data_response = "data_response" data_response = "data_response"
mutation = "mutation" mutation = "mutation"
@@ -263,8 +263,8 @@ class WsAgentComplete(BaseModel):
# ── WebSocket v3 Frame Models ───────────────────────────────────────── # ── WebSocket v3 Frame Models ─────────────────────────────────────────
class WsPopupScope(BaseModel): class WsFloatingScope(BaseModel):
"""Scope for a popup request — narrows the agent to a specific entity.""" """Scope for a floating request — narrows the agent to a specific entity."""
type: Literal["task", "project", "note", "checkpoint"] type: Literal["task", "project", "note", "checkpoint"]
id: str | None = None id: str | None = None
@@ -278,12 +278,12 @@ class WsHomeRequest(BaseModel):
conversation_history: list[dict[str, Any]] = Field(default_factory=list) conversation_history: list[dict[str, Any]] = Field(default_factory=list)
class WsPopupRequest(BaseModel): class WsFloatingRequest(BaseModel):
"""Client → Server: Popup chat message scoped to an entity.""" """Client → Server: Floating chat message scoped to an entity."""
type: Literal[WsFrameType.popup_request] = WsFrameType.popup_request type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
message: str message: str
scope: WsPopupScope scope: WsFloatingScope
class WsStreamStart(BaseModel): class WsStreamStart(BaseModel):
@@ -318,10 +318,10 @@ class WsStreamEnd(BaseModel):
mutations: list[dict[str, Any]] = Field(default_factory=list) mutations: list[dict[str, Any]] = Field(default_factory=list)
class WsPopupDomain(BaseModel): class WsFloatingDomain(BaseModel):
"""Server → Client: domain determined for a popup request.""" """Server → Client: domain determined for a floating request."""
type: Literal[WsFrameType.popup_domain] = WsFrameType.popup_domain type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
request_id: str request_id: str
domain: Literal["tasks", "checkpoints", "notes", "projects"] domain: Literal["tasks", "checkpoints", "notes", "projects"]

View File

@@ -1,12 +1,12 @@
"""Tests for app.core.output_formatter — HomeFormatter and PopupFormatter.""" """Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
from __future__ import annotations from __future__ import annotations
import pytest import pytest
from app.core.output_formatter import HomeFormatter, PopupFormatter from app.core.output_formatter import HomeFormatter, FloatingFormatter
from app.schemas import ( from app.schemas import (
WsPopupDomain, WsFloatingDomain,
WsStreamBlock, WsStreamBlock,
WsStreamEnd, WsStreamEnd,
WsStreamStart, WsStreamStart,
@@ -134,12 +134,12 @@ async def test_home_formatter_frame_order():
assert isinstance(frames[-1], WsStreamEnd) assert isinstance(frames[-1], WsStreamEnd)
# ── PopupFormatter ──────────────────────────────────────────────────────────── # ── FloatingFormatter ────────────────────────────────────────────────────────────
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_popup_formatter_domain_emitted_first(): async def test_floating_formatter_domain_emitted_first():
req_id = "pop-1" req_id = "pop-1"
formatter = PopupFormatter(request_id=req_id) formatter = FloatingFormatter(request_id=req_id)
tokens = [ tokens = [
("task_agent", ""), # domain signal ("task_agent", ""), # domain signal
("task_agent", "Hello"), ("task_agent", "Hello"),
@@ -147,19 +147,19 @@ async def test_popup_formatter_domain_emitted_first():
] ]
frames = await collect(formatter, _stream(*tokens)) frames = await collect(formatter, _stream(*tokens))
assert isinstance(frames[0], WsPopupDomain) assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "tasks" assert frames[0].domain == "tasks"
assert frames[0].request_id == req_id assert frames[0].request_id == req_id
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_popup_formatter_text_only(): async def test_floating_formatter_text_only():
req_id = "pop-2" req_id = "pop-2"
formatter = PopupFormatter(request_id=req_id) formatter = FloatingFormatter(request_id=req_id)
tokens = [("checkpoint_agent", ""), ("checkpoint_agent", "Summary")] tokens = [("checkpoint_agent", ""), ("checkpoint_agent", "Summary")]
frames = await collect(formatter, _stream(*tokens)) frames = await collect(formatter, _stream(*tokens))
assert isinstance(frames[0], WsPopupDomain) assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "checkpoints" assert frames[0].domain == "checkpoints"
text_frames = [f for f in frames if isinstance(f, WsStreamText)] text_frames = [f for f in frames if isinstance(f, WsStreamText)]
assert len(text_frames) == 1 assert len(text_frames) == 1
@@ -167,10 +167,10 @@ async def test_popup_formatter_text_only():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_popup_formatter_no_block_frames(): async def test_floating_formatter_no_block_frames():
"""PopupFormatter must never emit WsStreamBlock.""" """FloatingFormatter must never emit WsStreamBlock."""
req_id = "pop-3" req_id = "pop-3"
formatter = PopupFormatter(request_id=req_id) formatter = FloatingFormatter(request_id=req_id)
tokens = [ tokens = [
("note_agent", ""), ("note_agent", ""),
("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'), ("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'),
@@ -180,16 +180,16 @@ async def test_popup_formatter_no_block_frames():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_popup_formatter_end_frame(): async def test_floating_formatter_end_frame():
req_id = "pop-4" req_id = "pop-4"
formatter = PopupFormatter(request_id=req_id) formatter = FloatingFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done"))) frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done")))
assert isinstance(frames[-1], WsStreamEnd) assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_popup_formatter_unknown_agent_defaults_to_tasks(): async def test_floating_formatter_unknown_agent_defaults_to_tasks():
req_id = "pop-5" req_id = "pop-5"
formatter = PopupFormatter(request_id=req_id) formatter = FloatingFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi"))) frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi")))
assert frames[0].domain == "tasks" assert frames[0].domain == "tasks"

View File

@@ -6,9 +6,9 @@ from pydantic import ValidationError
from app.schemas import ( from app.schemas import (
WsFrameType, WsFrameType,
WsHomeRequest, WsHomeRequest,
WsPopupDomain, WsFloatingDomain,
WsPopupRequest, WsFloatingRequest,
WsPopupScope, WsFloatingScope,
WsStreamBlock, WsStreamBlock,
WsStreamEnd, WsStreamEnd,
WsStreamStart, WsStreamStart,
@@ -22,12 +22,12 @@ from app.schemas import (
def test_v3_frame_types_exist(): def test_v3_frame_types_exist():
v3_types = [ v3_types = [
"home_request", "home_request",
"popup_request", "floating_request",
"stream_start", "stream_start",
"stream_text", "stream_text",
"stream_block", "stream_block",
"stream_end", "stream_end",
"popup_domain", "floating_domain",
"data_request", "data_request",
"data_response", "data_response",
"mutation", "mutation",
@@ -90,49 +90,49 @@ def test_home_request_requires_message():
WsHomeRequest.model_validate({"type": "home_request"}) WsHomeRequest.model_validate({"type": "home_request"})
# ── WsPopupRequest ──────────────────────────────────────────────────── # ── WsFloatingRequest ────────────────────────────────────────────────────
def test_popup_request_basic(): def test_floating_request_basic():
frame = WsPopupRequest( frame = WsFloatingRequest(
message="Summarise", message="Summarise",
scope=WsPopupScope(type="task", id="task-123"), scope=WsFloatingScope(type="task", id="task-123"),
) )
assert frame.type == WsFrameType.popup_request assert frame.type == WsFrameType.floating_request
assert frame.scope.type == "task" assert frame.scope.type == "task"
assert frame.scope.id == "task-123" assert frame.scope.id == "task-123"
def test_popup_request_scope_without_id(): def test_floating_request_scope_without_id():
frame = WsPopupRequest( frame = WsFloatingRequest(
message="Show all", message="Show all",
scope=WsPopupScope(type="project"), scope=WsFloatingScope(type="project"),
) )
assert frame.scope.id is None assert frame.scope.id is None
def test_popup_request_serializes(): def test_floating_request_serializes():
frame = WsPopupRequest( frame = WsFloatingRequest(
message="Test", message="Test",
scope=WsPopupScope(type="note", id="n-1"), scope=WsFloatingScope(type="note", id="n-1"),
) )
data = frame.model_dump() data = frame.model_dump()
assert data["type"] == "popup_request" assert data["type"] == "floating_request"
assert data["scope"]["type"] == "note" assert data["scope"]["type"] == "note"
assert data["scope"]["id"] == "n-1" assert data["scope"]["id"] == "n-1"
def test_popup_request_invalid_scope_type(): def test_floating_request_invalid_scope_type():
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
WsPopupRequest( WsFloatingRequest(
message="X", message="X",
scope=WsPopupScope(type="unknown"), # type: ignore[arg-type] scope=WsFloatingScope(type="unknown"), # type: ignore[arg-type]
) )
def test_popup_request_requires_scope(): def test_floating_request_requires_scope():
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
WsPopupRequest.model_validate({"type": "popup_request", "message": "X"}) WsFloatingRequest.model_validate({"type": "floating_request", "message": "X"})
# ── WsStreamStart ───────────────────────────────────────────────────── # ── WsStreamStart ─────────────────────────────────────────────────────
@@ -261,32 +261,32 @@ def test_stream_end_deserializes():
assert frame.request_id == "r3" assert frame.request_id == "r3"
# ── WsPopupDomain ───────────────────────────────────────────────────── # ── WsFloatingDomain ─────────────────────────────────────────────────────
def test_popup_domain_tasks(): def test_floating_domain_tasks():
frame = WsPopupDomain(request_id="r1", domain="tasks") frame = WsFloatingDomain(request_id="r1", domain="tasks")
assert frame.type == WsFrameType.popup_domain assert frame.type == WsFrameType.floating_domain
assert frame.domain == "tasks" assert frame.domain == "tasks"
@pytest.mark.parametrize("domain", ["tasks", "checkpoints", "notes", "projects"]) @pytest.mark.parametrize("domain", ["tasks", "checkpoints", "notes", "projects"])
def test_popup_domain_valid_domains(domain: str): def test_floating_domain_valid_domains(domain: str):
frame = WsPopupDomain(request_id="r1", domain=domain) # type: ignore[arg-type] frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type]
assert frame.domain == domain assert frame.domain == domain
def test_popup_domain_invalid(): def test_floating_domain_invalid():
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
WsPopupDomain(request_id="r1", domain="invalid") # type: ignore[arg-type] WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type]
def test_popup_domain_serializes(): def test_floating_domain_serializes():
d = WsPopupDomain(request_id="r1", domain="notes").model_dump() d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
assert d == {"type": "popup_domain", "request_id": "r1", "domain": "notes"} assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}
def test_popup_domain_deserializes(): def test_floating_domain_deserializes():
raw = {"type": "popup_domain", "request_id": "r1", "domain": "projects"} raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
frame = WsPopupDomain.model_validate(raw) frame = WsFloatingDomain.model_validate(raw)
assert frame.domain == "projects" assert frame.domain == "projects"

View File

@@ -1,6 +1,6 @@
"""Integration tests for the unified WebSocket handler (Step 5). """Integration tests for the unified WebSocket handler (Step 5).
Tests the device WS endpoint with home_request and popup_request frames, Tests the device WS endpoint with home_request and floating_request frames,
verifying that the correct v3 frame sequence is returned. verifying that the correct v3 frame sequence is returned.
LLM calls are mocked to avoid network dependency. LLM calls are mocked to avoid network dependency.
@@ -34,7 +34,7 @@ def _override_db(db_session):
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]: def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
"""Receive frames until stream_end (or stream_end inside popup flow), or max_frames.""" """Receive frames until stream_end (or stream_end inside floating flow), or max_frames."""
frames = [] frames = []
for _ in range(max_frames): for _ in range(max_frames):
raw = ws.receive_text() raw = ws.receive_text()
@@ -50,7 +50,7 @@ async def _mock_home_stream(user_id, message, context, reg=None):
yield "task_agent", '{"type": "text", "content": "Hello"}' yield "task_agent", '{"type": "text", "content": "Hello"}'
async def _mock_popup_stream(user_id, message, context, reg=None): async def _mock_floating_stream(user_id, message, context, reg=None):
yield "task_agent", "" yield "task_agent", ""
yield "task_agent", "Here is a summary" yield "task_agent", "Here is a summary"
@@ -80,17 +80,17 @@ def test_home_request_produces_stream_frames(client):
assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end) assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)
def test_popup_request_produces_domain_frame(client): def test_floating_request_produces_domain_frame(client):
"""popup_request → popup_domain first, then stream_text*, stream_end.""" """floating_request → floating_domain first, then stream_text*, stream_end."""
token = make_jwt("power", user_id=USER_ID) token = make_jwt("power", user_id=USER_ID)
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_popup_stream): with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_floating_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-2", "agent_ids": [] "type": "device_hello", "device_id": "dev-2", "agent_ids": []
})) }))
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "popup_request", "type": "floating_request",
"request_id": "p1", "request_id": "p1",
"message": "Summarize this task", "message": "Summarize this task",
"scope": {"type": "task", "id": "task-123"}, "scope": {"type": "task", "id": "task-123"},
@@ -98,11 +98,11 @@ def test_popup_request_produces_domain_frame(client):
frames = _recv_until_end(ws) frames = _recv_until_end(ws)
types = [f["type"] for f in frames] types = [f["type"] for f in frames]
assert WsFrameType.popup_domain in types assert WsFrameType.floating_domain in types
assert WsFrameType.stream_end in types assert WsFrameType.stream_end in types
assert types.index(WsFrameType.popup_domain) < types.index(WsFrameType.stream_end) assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
domain_frame = next(f for f in frames if f["type"] == WsFrameType.popup_domain) domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
assert domain_frame["domain"] == "tasks" assert domain_frame["domain"] == "tasks"
assert domain_frame["request_id"] == "p1" assert domain_frame["request_id"] == "p1"