diff --git a/app/api/routes/device_ws.py b/app/api/routes/device_ws.py index 709eb90..46b36e1 100644 --- a/app/api/routes/device_ws.py +++ b/app/api/routes/device_ws.py @@ -42,8 +42,9 @@ from sqlalchemy import update from app.api.routes.agent_setup import handle_journey_message, handle_journey_start from app.config.settings import settings from app.core.agent_runner import trigger_pending_runs +from app.core.agent_session_buffer import session_buffer from app.core.brief_agent import run_home_brief, run_project_brief -from app.core.deep_agent import run_floating_stream, run_home_stream, run_task_brief_research_stream +from app.core.deep_agent import run_contextual_stream, run_floating_stream, run_home_stream, run_task_brief_research_stream from app.core.output_formatter import extract_canvas_block from app.core.device_manager import device_manager from app.core.memory_middleware import MemoryMiddleware @@ -52,6 +53,7 @@ from app.core.ws_context import clear_client_executor, set_client_executor from app.db import async_session from app.models import AgentRunLog from app.schemas import WsFrameType, WsStreamEnd +from app.schemas.contextual import ContextualScope, render_scope_block logger = logging.getLogger(__name__) @@ -197,6 +199,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None: elif frame_type == WsFrameType.index_session_cancel: await _handle_index_session_cancel(websocket, frame) + elif frame_type == WsFrameType.contextual_request: + asyncio.create_task( + _handle_contextual_request(websocket, user_id, frame) + ) + + elif frame_type == WsFrameType.contextual_scope_update: + asyncio.create_task( + _handle_contextual_scope_update(websocket, user_id, frame) + ) + elif frame_type == "pong": # Heartbeat ack — nothing to do, connection is alive. pass @@ -359,6 +371,121 @@ async def _handle_floating_request( ) +# ── v8 Contextual Sidebar Handlers ─────────────────────────────────── + + +def get_session_buffer(session_id: str, channel: str = "contextual"): + """Return the session buffer for the given session. + + The channel kwarg is accepted for forward-compatibility but not used for + namespacing yet (session ids are UUIDs so collisions are negligible). + Defined at module level so tests can monkeypatch it. + """ + return session_buffer + + +async def _handle_contextual_request( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Handle a contextual_request frame — runs the contextual agent and streams frames.""" + request_id = frame.get("request_id") or str(uuid4()) + message: str = frame.get("message", "") + session_id: str = frame.get("session_id") or str(uuid4()) + scope_payload: dict = frame.get("scope", {}) + logger.info( + "device_ws: contextual_request_start user=%s req=%s session=%s msg=%s", + user_id, + request_id, + session_id, + message[:200], + ) + + scope = ContextualScope.model_validate(scope_payload) + + # Enrich context with memory before the LLM call. + async with async_session() as db: + memory = MemoryMiddleware(db) + memory_context = await memory.enrich_context( + user_id, + message, + trace_id=request_id, + session_id=session_id, + ) + + context: dict = { + "conversation_history": frame.get("conversation_history", []), + "format_prefs": frame.get("format_prefs"), + "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, + **memory_context, + } + + executor = await _make_ws_executor(websocket, user_id) + set_client_executor(executor) + response_chunks: list[str] = [] + try: + event_stream = run_contextual_stream( + user_id=user_id, + message=message, + context=context, + scope=scope, + ) + formatter = StreamFormatter(request_id=request_id) + async for ws_frame in formatter.format(event_stream): + await websocket.send_text(ws_frame.model_dump_json()) + if ws_frame.type == "stream_text": # type: ignore[union-attr] + response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] + except Exception as exc: + logger.error( + "device_ws: contextual_request failed user=%s req=%s: %s", + user_id, request_id, exc, + ) + finally: + clear_client_executor() + + # Store episode so the contextual agent can recall prior turns. + async with async_session() as db: + memory = MemoryMiddleware(db) + await memory.store_episode( + user_id, session_id, message, "".join(response_chunks), trace_id=request_id + ) + logger.info( + "device_ws: contextual_request_end user=%s req=%s session=%s response_chars=%d", + user_id, + request_id, + session_id, + len("".join(response_chunks)), + ) + + +async def _handle_contextual_scope_update( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Handle a contextual_scope_update frame. + + Injects a synthetic system message into the session buffer so the next + agent turn knows the user navigated. No LLM call is made. + """ + session_id: str = frame.get("session_id") or str(uuid4()) + scope = ContextualScope.model_validate(frame.get("scope", {})) + block = render_scope_block(scope) + buf = get_session_buffer(session_id, channel="contextual") + buf.append_system_message( + f"User navigated to a new view. {block} Treat this as the new active context." + ) + await websocket.send_text(json.dumps({ + "type": WsFrameType.contextual_scope_ack, + "session_id": session_id, + })) + logger.info( + "device_ws: contextual_scope_update user=%s session=%s page=%s", + user_id, session_id, scope.page, + ) + + async def _handle_brief_request( websocket: WebSocket, user_id: str, diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py index ba4d283..e372c5e 100644 --- a/app/schemas/__init__.py +++ b/app/schemas/__init__.py @@ -96,6 +96,10 @@ class WsFrameType(str, Enum): index_file_result = "index_file_result" index_session_progress = "index_session_progress" index_session_done = "index_session_done" + # ── v8 contextual sidebar frame types ──────────────────────────── + contextual_request = "contextual_request" + contextual_scope_update = "contextual_scope_update" + contextual_scope_ack = "contextual_scope_ack" class WsToolCall(BaseModel): diff --git a/tests/test_contextual_ws.py b/tests/test_contextual_ws.py new file mode 100644 index 0000000..01f3b25 --- /dev/null +++ b/tests/test_contextual_ws.py @@ -0,0 +1,44 @@ +"""Tests for contextual WS frame handlers. + +These tests only exercise the new handler functions in device_ws.py and do +not depend on litellm or the full deep_agent import chain. They monkeypatch +run_contextual_stream so no LLM call is made. +""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + + +@pytest.mark.asyncio +async def test_handle_contextual_scope_update_appends_system_message_no_llm(monkeypatch): + """_handle_contextual_scope_update must: + - call append_system_message on the session buffer + - send a contextual_scope_ack back on the socket + - make no LLM call + """ + from app.api.routes import device_ws + + ws = AsyncMock() + buffer = MagicMock() + buffer.append_system_message = MagicMock() + + payload = { + "type": "contextual_scope_update", + "session_id": "s1", + "scope": { + "page": "project", + "entityType": "project", + "entityId": "p1", + "entityName": "Acme", + "counts": {"tasks": 1, "notes": 0, "milestones": 0}, + }, + } + + monkeypatch.setattr(device_ws, "get_session_buffer", lambda *a, **kw: buffer) + await device_ws._handle_contextual_scope_update(ws, "user1", payload) + + ws.send_text.assert_awaited_once() + import json + sent = json.loads(ws.send_text.await_args.args[0]) + assert sent["type"] == "contextual_scope_ack" + assert sent["session_id"] == "s1" + buffer.append_system_message.assert_called_once()