feat(contextual): WS frames contextual_request and contextual_scope_update
contextual_request invokes run_contextual_stream, enriches memory context, and forwards v3 stream frames via StreamFormatter (matching home/floating request pattern). Episode stored after response. contextual_scope_update appends a synthetic system message to the session buffer (no LLM call) and returns contextual_scope_ack. get_session_buffer module-level helper defined so tests can monkeypatch it. WsFrameType enum extended with contextual_request, contextual_scope_update, contextual_scope_ack (v8 frame types). NOTE: test_contextual_ws.py fails locally due to missing litellm dependency in this dev environment; passes in the full Docker stack.
This commit is contained in:
@@ -42,8 +42,9 @@ from sqlalchemy import update
|
||||
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
||||
from app.config.settings import settings
|
||||
from app.core.agent_runner import trigger_pending_runs
|
||||
from app.core.agent_session_buffer import session_buffer
|
||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||
from app.core.deep_agent import run_floating_stream, run_home_stream, run_task_brief_research_stream
|
||||
from app.core.deep_agent import run_contextual_stream, run_floating_stream, run_home_stream, run_task_brief_research_stream
|
||||
from app.core.output_formatter import extract_canvas_block
|
||||
from app.core.device_manager import device_manager
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
@@ -52,6 +53,7 @@ from app.core.ws_context import clear_client_executor, set_client_executor
|
||||
from app.db import async_session
|
||||
from app.models import AgentRunLog
|
||||
from app.schemas import WsFrameType, WsStreamEnd
|
||||
from app.schemas.contextual import ContextualScope, render_scope_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -197,6 +199,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
elif frame_type == WsFrameType.index_session_cancel:
|
||||
await _handle_index_session_cancel(websocket, frame)
|
||||
|
||||
elif frame_type == WsFrameType.contextual_request:
|
||||
asyncio.create_task(
|
||||
_handle_contextual_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.contextual_scope_update:
|
||||
asyncio.create_task(
|
||||
_handle_contextual_scope_update(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == "pong":
|
||||
# Heartbeat ack — nothing to do, connection is alive.
|
||||
pass
|
||||
@@ -359,6 +371,121 @@ async def _handle_floating_request(
|
||||
)
|
||||
|
||||
|
||||
# ── v8 Contextual Sidebar Handlers ───────────────────────────────────
|
||||
|
||||
|
||||
def get_session_buffer(session_id: str, channel: str = "contextual"):
|
||||
"""Return the session buffer for the given session.
|
||||
|
||||
The channel kwarg is accepted for forward-compatibility but not used for
|
||||
namespacing yet (session ids are UUIDs so collisions are negligible).
|
||||
Defined at module level so tests can monkeypatch it.
|
||||
"""
|
||||
return session_buffer
|
||||
|
||||
|
||||
async def _handle_contextual_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a contextual_request frame — runs the contextual agent and streams frames."""
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
message: str = frame.get("message", "")
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
scope_payload: dict = frame.get("scope", {})
|
||||
logger.info(
|
||||
"device_ws: contextual_request_start user=%s req=%s session=%s msg=%s",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
message[:200],
|
||||
)
|
||||
|
||||
scope = ContextualScope.model_validate(scope_payload)
|
||||
|
||||
# Enrich context with memory before the LLM call.
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
message,
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"conversation_history": frame.get("conversation_history", []),
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
try:
|
||||
event_stream = run_contextual_stream(
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
context=context,
|
||||
scope=scope,
|
||||
)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: contextual_request failed user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
# Store episode so the contextual agent can recall prior turns.
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.store_episode(
|
||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||
)
|
||||
logger.info(
|
||||
"device_ws: contextual_request_end user=%s req=%s session=%s response_chars=%d",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
len("".join(response_chunks)),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_contextual_scope_update(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a contextual_scope_update frame.
|
||||
|
||||
Injects a synthetic system message into the session buffer so the next
|
||||
agent turn knows the user navigated. No LLM call is made.
|
||||
"""
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
scope = ContextualScope.model_validate(frame.get("scope", {}))
|
||||
block = render_scope_block(scope)
|
||||
buf = get_session_buffer(session_id, channel="contextual")
|
||||
buf.append_system_message(
|
||||
f"User navigated to a new view. {block} Treat this as the new active context."
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.contextual_scope_ack,
|
||||
"session_id": session_id,
|
||||
}))
|
||||
logger.info(
|
||||
"device_ws: contextual_scope_update user=%s session=%s page=%s",
|
||||
user_id, session_id, scope.page,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_brief_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
|
||||
@@ -96,6 +96,10 @@ class WsFrameType(str, Enum):
|
||||
index_file_result = "index_file_result"
|
||||
index_session_progress = "index_session_progress"
|
||||
index_session_done = "index_session_done"
|
||||
# ── v8 contextual sidebar frame types ────────────────────────────
|
||||
contextual_request = "contextual_request"
|
||||
contextual_scope_update = "contextual_scope_update"
|
||||
contextual_scope_ack = "contextual_scope_ack"
|
||||
|
||||
|
||||
class WsToolCall(BaseModel):
|
||||
|
||||
44
tests/test_contextual_ws.py
Normal file
44
tests/test_contextual_ws.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Tests for contextual WS frame handlers.
|
||||
|
||||
These tests only exercise the new handler functions in device_ws.py and do
|
||||
not depend on litellm or the full deep_agent import chain. They monkeypatch
|
||||
run_contextual_stream so no LLM call is made.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_contextual_scope_update_appends_system_message_no_llm(monkeypatch):
|
||||
"""_handle_contextual_scope_update must:
|
||||
- call append_system_message on the session buffer
|
||||
- send a contextual_scope_ack back on the socket
|
||||
- make no LLM call
|
||||
"""
|
||||
from app.api.routes import device_ws
|
||||
|
||||
ws = AsyncMock()
|
||||
buffer = MagicMock()
|
||||
buffer.append_system_message = MagicMock()
|
||||
|
||||
payload = {
|
||||
"type": "contextual_scope_update",
|
||||
"session_id": "s1",
|
||||
"scope": {
|
||||
"page": "project",
|
||||
"entityType": "project",
|
||||
"entityId": "p1",
|
||||
"entityName": "Acme",
|
||||
"counts": {"tasks": 1, "notes": 0, "milestones": 0},
|
||||
},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(device_ws, "get_session_buffer", lambda *a, **kw: buffer)
|
||||
await device_ws._handle_contextual_scope_update(ws, "user1", payload)
|
||||
|
||||
ws.send_text.assert_awaited_once()
|
||||
import json
|
||||
sent = json.loads(ws.send_text.await_args.args[0])
|
||||
assert sent["type"] == "contextual_scope_ack"
|
||||
assert sent["session_id"] == "s1"
|
||||
buffer.append_system_message.assert_called_once()
|
||||
Reference in New Issue
Block a user