_SessionBuffer.append_system_message(user_id, session_id, text) injects a synthetic SystemMessage into the named session slot (creating it if absent). ContextualBufferProxy closes over user_id + session_id so call sites need only call proxy.append_system_message(text). get_session_buffer(user_id, session_id, channel) in device_ws returns a ContextualBufferProxy, keeping the test-patchable function signature intact.
921 lines
34 KiB
Python
921 lines
34 KiB
Python
"""Device WebSocket endpoint.
|
|
|
|
Persistent connection from Electron devices to the backend.
|
|
|
|
WS /api/v1/ws/device?token=<jwt>
|
|
|
|
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
|
|
available during the WebSocket handshake).
|
|
|
|
Protocol:
|
|
1. Client connects → JWT validated → connection accepted.
|
|
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
|
|
3. Backend registers the connection in ``DeviceConnectionManager``.
|
|
4. Session enters message dispatch loop + heartbeat.
|
|
|
|
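Incoming frame dispatch:
- ``tool_result`` → resolves a pending tool-call Future.
- ``home_request`` / ``floating_request`` / ``contextual_request`` → run an agent and stream its response.
- ``contextual_scope_update`` → records a navigation change in the session buffer and acks it.
- ``brief_request`` / ``task_brief_request`` → stream a home/project brief or a task research brief.
- ``journey_start`` → starts a guided setup journey session.
- ``journey_message`` → continues a journey conversation.
- ``index_session_start`` / ``index_file_batch`` / ``index_session_cancel`` → folder indexing sessions.
- ``pong`` → heartbeat acknowledgement (currently a no-op).
- unknown types → logged at debug level, ignored.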
|
|
|
|
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
|
|
|
|
On disconnect:
|
|
- Unregisters from DeviceConnectionManager.
|
|
- Marks all in-progress AgentRunLog rows for this user as ``error``
|
|
with message "device disconnected".
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from uuid import uuid4
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
from jose import JWTError, jwt
|
|
from sqlalchemy import update
|
|
|
|
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
|
from app.config.settings import settings
|
|
from app.core.agent_runner import trigger_pending_runs
|
|
from app.core.agent_session_buffer import session_buffer
|
|
from app.core.brief_agent import run_home_brief, run_project_brief
|
|
from app.core.deep_agent import run_contextual_stream, run_floating_stream, run_home_stream, run_task_brief_research_stream
|
|
from app.core.output_formatter import extract_canvas_block
|
|
from app.core.device_manager import device_manager
|
|
from app.core.memory_middleware import MemoryMiddleware
|
|
from app.core.output_formatter import StreamFormatter
|
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
|
from app.db import async_session
|
|
from app.models import AgentRunLog
|
|
from app.schemas import WsFrameType, WsStreamEnd
|
|
from app.schemas.contextual import ContextualScope, render_scope_block
|
|
|
|
logger = logging.getLogger(__name__)

router = APIRouter(prefix="/ws", tags=["device-ws"])

# ── v7 folder index session state ─────────────────────────────────────
# In-process registry of live folder-index sessions, keyed by sessionId.
# Each value is a dict with the keys: user_id, project_id, processed,
# total, cancelled. Entries are dropped when a session completes, is
# cancelled, or runs out of quota.
_index_sessions: dict[str, dict] = {}

# Seconds between outgoing ``ping`` frames.
_HEARTBEAT_INTERVAL: int = 30

# Seconds of grace after a ping. NOTE(review): not referenced anywhere else
# in this module, so no pong deadline is currently enforced.
_PONG_TIMEOUT: int = 10
|
|
|
|
|
|
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
    """Persistent WebSocket endpoint for Electron device connections.

    Authentication is via the ``?token=<jwt>`` query parameter. The socket is
    closed with code 1008 (policy violation) when the token is invalid or has
    no ``sub``, when no first frame arrives within 15 s, or when that first
    frame is not a well-formed ``device_hello``.

    After registration the message loop and the heartbeat run as two sibling
    tasks. As soon as either one finishes (client disconnect, failed send,
    unexpected error) the other is cancelled. Disconnect cleanup therefore
    runs immediately instead of waiting for the next heartbeat send to fail.
    """
    # ── 1. Authenticate before accepting ─────────────────────────────
    token = websocket.query_params.get("token", "")
    try:
        payload = jwt.decode(
            token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
        user_id: str | None = payload.get("sub")
        if not user_id:
            raise JWTError("missing sub")
    except JWTError:
        await websocket.close(code=1008)  # Policy Violation
        return

    await websocket.accept()

    # ── 2. Await device_hello frame ───────────────────────────────────
    try:
        raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
    except WebSocketDisconnect:
        # The client is already gone; closing again would only raise.
        return
    except asyncio.TimeoutError:
        await websocket.close(code=1008)
        return

    try:
        hello = json.loads(raw)
        # A JSON array/scalar has no ``.get``; reject it like any bad hello.
        if not isinstance(hello, dict) or hello.get("type") != WsFrameType.device_hello:
            raise ValueError("expected device_hello as first frame")
        device_id: str = hello["device_id"]
        agent_ids: list[str] = hello.get("agent_ids", [])
    except (KeyError, ValueError, json.JSONDecodeError) as exc:
        logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
        await websocket.close(code=1008)
        return

    # ── 3. Register connection ────────────────────────────────────────
    device_manager.register(user_id, device_id, websocket)
    logger.info(
        "device_ws: connected user=%s device=%s agents=%s",
        user_id,
        device_id,
        agent_ids,
    )

    # Trigger any overdue agent runs now that the device is connected.
    asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))

    # ── 4. Concurrent message loop + heartbeat ────────────────────────
    loop_tasks = [
        asyncio.create_task(_message_loop(websocket, user_id)),
        asyncio.create_task(_heartbeat_loop(websocket)),
    ]
    try:
        # Stop as soon as either side ends. The message loop returns normally
        # when the client disconnects. The heartbeat loop only ends by raising
        # on a failed send. With ``gather`` the survivor would keep running.
        done, _still_running = await asyncio.wait(
            loop_tasks, return_when=asyncio.FIRST_COMPLETED
        )
        for task in done:
            if task.cancelled():
                continue
            exc = task.exception()
            if exc is not None and not isinstance(exc, WebSocketDisconnect):
                logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
    finally:
        # Cancel whatever is still running and wait for it to unwind, so no
        # orphaned task keeps writing to the dead socket.
        for task in loop_tasks:
            task.cancel()
        await asyncio.gather(*loop_tasks, return_exceptions=True)
        device_manager.unregister(user_id)
        logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
        await _mark_runs_disconnected(user_id)
|
|
|
|
|
|
# ── Message dispatch loop ─────────────────────────────────────────────
|
|
|
|
# Strong references to fire-and-forget handler tasks. The event loop keeps
# only weak references to tasks, so an otherwise unreferenced task may be
# garbage-collected before it finishes. Each task removes itself when done.
_background_tasks: set[asyncio.Task] = set()


def _spawn(coro) -> asyncio.Task:
    """Schedule *coro* as a background task and keep it alive until it finishes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


async def _message_loop(websocket: WebSocket, user_id: str) -> None:
    """Receive frames from Electron and dispatch to the appropriate handler.

    ``tool_result`` frames resolve pending tool calls inline.
    ``index_session_cancel`` is awaited inline. ``pong`` is ignored. Every
    other known frame type is handled in its own background task, so a slow
    handler (e.g. an LLM stream) does not block receipt of further frames.

    Malformed frames (invalid JSON, or JSON that is not an object) are logged
    and skipped rather than tearing down the connection. The loop ends when
    ``websocket.iter_text()`` stops yielding.
    """
    # (frame type, handler) pairs for the background-task handlers. Matched
    # with ``==`` (not a dict lookup) so the comparison semantics of the
    # WsFrameType members are exactly those of an if/elif chain.
    spawned_handlers = (
        (WsFrameType.home_request, _handle_home_request),
        (WsFrameType.floating_request, _handle_floating_request),
        (WsFrameType.brief_request, _handle_brief_request),
        (WsFrameType.task_brief_request, _handle_task_brief_request),
        (WsFrameType.journey_start, _handle_journey_start),
        (WsFrameType.journey_message, _handle_journey_message),
        (WsFrameType.index_session_start, _handle_index_session_start),
        (WsFrameType.index_file_batch, _handle_index_file_batch),
        (WsFrameType.contextual_request, _handle_contextual_request),
        (WsFrameType.contextual_scope_update, _handle_contextual_scope_update),
    )

    async for raw in websocket.iter_text():
        try:
            frame = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("device_ws: invalid JSON from user=%s", user_id)
            continue
        # Valid JSON that is not an object would make ``frame.get`` raise and
        # kill the whole connection; treat it as a malformed frame instead.
        if not isinstance(frame, dict):
            logger.warning("device_ws: non-object frame from user=%s", user_id)
            continue

        frame_type = frame.get("type")

        if frame_type == WsFrameType.tool_result:
            call_id = frame.get("id")
            if call_id:
                device_manager.resolve_pending_call(user_id, call_id, frame)
            else:
                logger.warning(
                    "device_ws: tool_result missing id from user=%s", user_id
                )
            continue

        if frame_type == WsFrameType.index_session_cancel:
            await _handle_index_session_cancel(websocket, frame)
            continue

        if frame_type == "pong":
            # Heartbeat ack — nothing to do, connection is alive.
            continue

        handler = next(
            (fn for kind, fn in spawned_handlers if frame_type == kind), None
        )
        if handler is not None:
            _spawn(handler(websocket, user_id, frame))
        else:
            logger.debug(
                "device_ws: unknown frame type %r from user=%s", frame_type, user_id
            )
|
|
|
|
|
|
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
|
|
|
async def _make_ws_executor(websocket: WebSocket, user_id: str):
    """Return a callback that sends tool_call frames and awaits tool_result.

    The returned coroutine function takes a tool-call payload, which must
    contain an ``id``. It sends the payload to the device as a ``tool_call``
    frame and returns whatever the pending call registered under that id
    resolves to. The message loop resolves it with the device's
    ``tool_result`` frame via ``device_manager.resolve_pending_call``.
    """

    async def _executor(payload: dict) -> dict:
        call_id = payload["id"]
        # Build the outgoing frame without mutating the caller's payload.
        outgoing = {**payload, "type": WsFrameType.tool_call}
        # Register the pending call *before* sending. ``send_text`` yields to
        # the event loop, so a fast device reply could otherwise be dispatched
        # by the message loop before any future exists to receive it.
        future = device_manager.create_pending_call(user_id, call_id)
        try:
            await websocket.send_text(json.dumps(outgoing))
        except BaseException:
            # The call never reached the device, so nothing will resolve it.
            # NOTE(review): assumes create_pending_call returns an
            # asyncio.Future; the registry entry is left to device_manager.
            future.cancel()
            raise
        return await future

    return _executor
|
|
|
|
|
|
async def _handle_home_request(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Handle a home_request frame — streams StreamFormatter output back on the socket.

    Steps:
    1. Enrich the request with memory context.
    2. Run the home agent stream with the websocket tool executor installed.
    3. Forward every formatted frame to the client.
    4. Store the exchange as a memory episode.

    If streaming raises, the failure is logged with its traceback, and an
    error ``stream_end`` frame is sent (best effort) so the client is not left
    waiting. Whatever text was streamed before the failure is still stored as
    the episode response.

    Args:
        websocket: Device connection to stream frames on.
        user_id: Authenticated user id.
        frame: Raw ``home_request`` payload. ``project_id`` is also accepted
            as ``projectId``.
    """
    request_id = frame.get("request_id") or str(uuid4())
    message: str = frame.get("message", "")
    session_id: str = frame.get("session_id") or str(uuid4())
    project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
    logger.info(
        "device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
        user_id,
        request_id,
        session_id,
        project_id,
        message[:200],
    )

    # ── Memory: enrich context before LLM call ────────────────────────
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        memory_context = await memory.enrich_context(
            user_id,
            message,
            trace_id=request_id,
            session_id=session_id,
        )

    context: dict = {
        "conversation_history": frame.get("conversation_history", []),
        "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
        "format_prefs": frame.get("format_prefs"),
        **memory_context,
    }

    executor = await _make_ws_executor(websocket, user_id)
    set_client_executor(executor)
    response_chunks: list[str] = []
    try:
        event_stream = run_home_stream(user_id, message, context, project_id=project_id)
        formatter = StreamFormatter(request_id=request_id)
        async for ws_frame in formatter.format(event_stream):
            await websocket.send_text(ws_frame.model_dump_json())
            # Collect text chunks to build the full response for episode storage
            if ws_frame.type == "stream_text":  # type: ignore[union-attr]
                response_chunks.append(ws_frame.chunk)  # type: ignore[union-attr]
    except Exception as exc:
        logger.exception(
            "device_ws: home_request failed user=%s req=%s: %s",
            user_id, request_id, exc,
        )
        # Terminate the request for the client, as the brief handlers do.
        try:
            await websocket.send_text(
                WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
            )
        except Exception:
            # The socket is most likely closed already; nothing more to do.
            logger.debug(
                "device_ws: home_request error frame not sent req=%s", request_id
            )
    finally:
        clear_client_executor()

    response_text = "".join(response_chunks)

    # ── Memory: store episode after response ──────────────────────────
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        await memory.store_episode(
            user_id, session_id, message, response_text, trace_id=request_id
        )
    logger.info(
        "device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
        user_id,
        request_id,
        session_id,
        len(response_text),
    )
|
|
|
|
|
|
async def _handle_floating_request(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Handle a floating_request frame — streams StreamFormatter output back on the socket.

    Steps:
    1. Enrich the request with memory context.
    2. Run the floating agent stream (with the client-supplied ``scope`` in
       the context) while the websocket tool executor is installed.
    3. Forward every formatted frame to the client.
    4. Store the exchange as a memory episode.

    If streaming raises, the failure is logged with its traceback, and an
    error ``stream_end`` frame is sent (best effort) so the client is not left
    waiting. Text streamed before the failure is still stored.

    Args:
        websocket: Device connection to stream frames on.
        user_id: Authenticated user id.
        frame: Raw ``floating_request`` payload.
    """
    request_id = frame.get("request_id") or str(uuid4())
    message: str = frame.get("message", "")
    session_id: str = frame.get("session_id") or str(uuid4())
    scope: dict = frame.get("scope", {})
    logger.info(
        "device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
        user_id,
        request_id,
        session_id,
        json.dumps(scope, ensure_ascii=True)[:200],
        message[:200],
    )

    # ── Memory: enrich context before LLM call ────────────────────────
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        memory_context = await memory.enrich_context(
            user_id,
            message,
            trace_id=request_id,
            session_id=session_id,
        )

    context: dict = {
        "conversation_history": frame.get("conversation_history", []),
        "scope": scope,
        "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
        "format_prefs": frame.get("format_prefs"),
        **memory_context,
    }

    executor = await _make_ws_executor(websocket, user_id)
    set_client_executor(executor)
    response_chunks: list[str] = []
    try:
        event_stream = run_floating_stream(user_id, message, context)
        formatter = StreamFormatter(request_id=request_id)
        async for ws_frame in formatter.format(event_stream):
            await websocket.send_text(ws_frame.model_dump_json())
            if ws_frame.type == "stream_text":  # type: ignore[union-attr]
                response_chunks.append(ws_frame.chunk)  # type: ignore[union-attr]
    except Exception as exc:
        logger.exception(
            "device_ws: floating_request failed user=%s req=%s: %s",
            user_id, request_id, exc,
        )
        # Terminate the request for the client, as the brief handlers do.
        try:
            await websocket.send_text(
                WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
            )
        except Exception:
            # The socket is most likely closed already; nothing more to do.
            logger.debug(
                "device_ws: floating_request error frame not sent req=%s", request_id
            )
    finally:
        clear_client_executor()

    response_text = "".join(response_chunks)

    # ── Memory: store episode after response ──────────────────────────
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        await memory.store_episode(
            user_id, session_id, message, response_text, trace_id=request_id
        )
    logger.info(
        "device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
        user_id,
        request_id,
        session_id,
        len(response_text),
    )
|
|
|
|
|
|
# ── v8 Contextual Sidebar Handlers ───────────────────────────────────
|
|
|
|
|
|
def get_session_buffer(user_id: str, session_id: str, channel: str = "contextual"):
    """Return a session-scoped buffer proxy for the given user and session.

    The result is a ``ContextualBufferProxy`` wrapping the shared
    ``session_buffer`` and bound to *user_id* and *session_id*. Callers can
    therefore inject text with just ``proxy.append_system_message(text)``.

    Defined at module level so tests can monkeypatch it.

    Args:
        user_id: Owner of the session.
        session_id: Session slot the proxy writes to.
        channel: Accepted for forward-compatibility; currently ignored.
    """
    from app.core.agent_session_buffer import ContextualBufferProxy  # noqa: PLC0415

    return ContextualBufferProxy(session_buffer, user_id, session_id)
|
|
|
|
|
|
async def _handle_contextual_request(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Handle a contextual_request frame — runs the contextual agent and streams frames.

    Steps:
    1. Validate the ``scope`` payload as a ``ContextualScope``. If it is
       invalid, send an error ``stream_end`` and stop.
    2. Enrich the request with memory context.
    3. Run the contextual agent stream with the websocket tool executor
       installed, forwarding every formatted frame.
    4. Store the exchange as a memory episode.

    If streaming raises, the failure is logged with its traceback, and an
    error ``stream_end`` is sent (best effort) so the client is not left
    waiting.

    Args:
        websocket: Device connection to stream frames on.
        user_id: Authenticated user id.
        frame: Raw ``contextual_request`` payload.
    """
    request_id = frame.get("request_id") or str(uuid4())
    message: str = frame.get("message", "")
    session_id: str = frame.get("session_id") or str(uuid4())
    scope_payload: dict = frame.get("scope", {})
    logger.info(
        "device_ws: contextual_request_start user=%s req=%s session=%s msg=%s",
        user_id,
        request_id,
        session_id,
        message[:200],
    )

    # Validate before any memory/LLM work so a bad scope still gets a reply
    # (pydantic's ValidationError is a ValueError subclass).
    try:
        scope = ContextualScope.model_validate(scope_payload)
    except ValueError as exc:
        logger.warning(
            "device_ws: contextual_request invalid scope user=%s req=%s: %s",
            user_id, request_id, exc,
        )
        await websocket.send_text(
            WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
        )
        return

    # Enrich context with memory before the LLM call.
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        memory_context = await memory.enrich_context(
            user_id,
            message,
            trace_id=request_id,
            session_id=session_id,
        )

    context: dict = {
        "conversation_history": frame.get("conversation_history", []),
        "format_prefs": frame.get("format_prefs"),
        "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
        **memory_context,
    }

    executor = await _make_ws_executor(websocket, user_id)
    set_client_executor(executor)
    response_chunks: list[str] = []
    try:
        event_stream = run_contextual_stream(
            user_id=user_id,
            message=message,
            context=context,
            scope=scope,
        )
        formatter = StreamFormatter(request_id=request_id)
        async for ws_frame in formatter.format(event_stream):
            await websocket.send_text(ws_frame.model_dump_json())
            if ws_frame.type == "stream_text":  # type: ignore[union-attr]
                response_chunks.append(ws_frame.chunk)  # type: ignore[union-attr]
    except Exception as exc:
        logger.exception(
            "device_ws: contextual_request failed user=%s req=%s: %s",
            user_id, request_id, exc,
        )
        # Terminate the request for the client, as the brief handlers do.
        try:
            await websocket.send_text(
                WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
            )
        except Exception:
            # The socket is most likely closed already; nothing more to do.
            logger.debug(
                "device_ws: contextual_request error frame not sent req=%s", request_id
            )
    finally:
        clear_client_executor()

    response_text = "".join(response_chunks)

    # Store episode so the contextual agent can recall prior turns.
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        await memory.store_episode(
            user_id, session_id, message, response_text, trace_id=request_id
        )
    logger.info(
        "device_ws: contextual_request_end user=%s req=%s session=%s response_chars=%d",
        user_id,
        request_id,
        session_id,
        len(response_text),
    )
|
|
|
|
|
|
async def _handle_contextual_scope_update(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Handle a contextual_scope_update frame.

    Injects a synthetic system message into the session buffer so the next
    agent turn knows the user navigated, then acknowledges with a
    ``contextual_scope_ack`` frame. No LLM call is made.

    A ``scope`` payload that fails validation is logged and ignored: no
    buffer write and no ack. It no longer escapes as an unhandled task
    exception.
    """
    session_id: str = frame.get("session_id") or str(uuid4())
    try:
        scope = ContextualScope.model_validate(frame.get("scope", {}))
    except ValueError as exc:  # pydantic ValidationError subclasses ValueError
        logger.warning(
            "device_ws: contextual_scope_update invalid scope user=%s session=%s: %s",
            user_id, session_id, exc,
        )
        return

    block = render_scope_block(scope)
    buf = get_session_buffer(user_id, session_id, channel="contextual")
    buf.append_system_message(
        f"User navigated to a new view. {block} Treat this as the new active context."
    )
    await websocket.send_text(json.dumps({
        "type": WsFrameType.contextual_scope_ack,
        "session_id": session_id,
    }))
    logger.info(
        "device_ws: contextual_scope_update user=%s session=%s page=%s",
        user_id, session_id, scope.page,
    )
|
|
|
|
|
|
async def _handle_brief_request(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Stream a plain-text brief (home or project) back to the device.

    ``mode == "project"`` requires a ``project_id`` that parses as a UUID.
    Otherwise an error ``stream_end`` is sent and nothing else happens. Any
    other mode produces the home brief. Memory context is loaded with an
    empty probe message because a brief has no user message. Briefs are not
    conversations, so no episode is stored.
    """
    import uuid as uuid_mod

    request_id = frame.get("request_id") or str(uuid4())
    session_id = frame.get("session_id") or str(uuid4())
    mode: str = frame.get("mode", "home")
    project_id: str | None = frame.get("project_id")
    wants_project = mode == "project"

    logger.info(
        "device_ws: brief_request_start user=%s req=%s mode=%s project_id=%s",
        user_id, request_id, mode, project_id,
    )

    # Reject a missing/malformed project_id up front, before any LLM work.
    if wants_project:
        problem: Exception | None = None
        if not project_id:
            problem = ValueError("project_id required for project mode")
        else:
            try:
                uuid_mod.UUID(project_id)
            except (ValueError, AttributeError) as exc:
                problem = exc
        if problem is not None:
            logger.warning(
                "device_ws: brief_request invalid project_id user=%s req=%s: %s",
                user_id, request_id, problem,
            )
            await websocket.send_text(
                WsStreamEnd(request_id=request_id, error=str(problem)).model_dump_json()
            )
            return

    # Memory enrichment uses an empty string as the probe message.
    async with async_session() as db:
        memory_context = await MemoryMiddleware(db).enrich_context(
            user_id,
            "",
            trace_id=request_id,
            session_id=session_id,
        )

    context: dict = {
        "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
        "format_prefs": frame.get("format_prefs"),
        **memory_context,
    }

    set_client_executor(await _make_ws_executor(websocket, user_id))
    try:
        if wants_project:
            event_stream = run_project_brief(user_id, project_id, context)  # type: ignore[arg-type]
        else:
            event_stream = run_home_brief(user_id, context)
        async for ws_frame in StreamFormatter(request_id=request_id).format(event_stream):
            await websocket.send_text(ws_frame.model_dump_json())
    except Exception as exc:
        logger.error(
            "device_ws: brief_request failed user=%s req=%s: %s",
            user_id, request_id, exc,
        )
        await websocket.send_text(
            WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
        )
    finally:
        clear_client_executor()

    logger.info(
        "device_ws: brief_request_end user=%s req=%s mode=%s",
        user_id, request_id, mode,
    )
|
|
|
|
|
|
# ── v6 Task Brief Handler ────────────────────────────────────────────
|
|
|
|
|
|
async def _handle_task_brief_request(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Handle a task_brief_request frame — Stage-1 executive assistant deep research.

    Streams the briefing markdown back to the client. Only ``stream_start``
    and ``stream_text`` frames from the formatter are forwarded. The
    formatter's ``stream_end`` is replaced by one sent here, which carries a
    ``canvas_draft`` mutation if the agent produced a canvas block. Any error
    reported on the formatter's ``stream_end`` is carried over to the
    replacement, so a failed run is not announced as a success.

    A missing ``task_id`` (also accepted as ``taskId``) is answered with an
    error ``stream_end``. If streaming raises, the failure is logged with its
    traceback and an error ``stream_end`` is sent (best effort).
    """
    request_id = frame.get("request_id") or str(uuid4())
    session_id = frame.get("session_id") or str(uuid4())
    task_id: str = frame.get("task_id") or frame.get("taskId") or ""
    project_id: str | None = frame.get("project_id") or frame.get("projectId") or None

    logger.info(
        "device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
        user_id, request_id, task_id, project_id,
    )

    if not task_id:
        await websocket.send_text(
            WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
        )
        return

    async with async_session() as db:
        memory = MemoryMiddleware(db)
        memory_context = await memory.enrich_context(
            user_id,
            f"task brief: {task_id}",
            trace_id=request_id,
            session_id=session_id,
        )

    context: dict = {
        "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
        "format_prefs": frame.get("format_prefs"),
        **memory_context,
    }

    executor = await _make_ws_executor(websocket, user_id)
    set_client_executor(executor)
    response_chunks: list[str] = []
    stream_error: str | None = None

    try:
        event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
        formatter = StreamFormatter(request_id=request_id)
        async for ws_frame in formatter.format(event_stream):
            if ws_frame.type == "stream_text":  # type: ignore[union-attr]
                response_chunks.append(ws_frame.chunk)  # type: ignore[union-attr]
                await websocket.send_text(ws_frame.model_dump_json())
            elif ws_frame.type == "stream_start":
                await websocket.send_text(ws_frame.model_dump_json())
            elif ws_frame.type == "stream_end":
                # Not forwarded: the stream_end sent below (with mutations)
                # replaces it. Keep its error, if it carries one.
                stream_error = getattr(ws_frame, "error", None) or stream_error
            # Any other frame type is not forwarded for this request.
    except Exception as exc:
        logger.exception(
            "device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
            user_id, request_id, task_id, exc,
        )
        try:
            await websocket.send_text(
                WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
            )
        except Exception:
            # The socket is most likely closed already; nothing more to do.
            logger.debug(
                "device_ws: task_brief_request error frame not sent req=%s", request_id
            )
        return
    finally:
        clear_client_executor()

    # Extract canvas block then emit stream_end with optional mutations.
    full_response = "".join(response_chunks)
    _visible, canvas_content, canvas_kind = extract_canvas_block(full_response)

    mutations: list[dict] = []
    if canvas_content:
        mutations.append({
            "type": "canvas_draft",
            "content": canvas_content,
            "kind": canvas_kind,
        })

    # Only pass ``error`` when there is one, leaving the model's default
    # untouched on the success path.
    end_fields: dict = {"mutations": mutations or None}
    if stream_error:
        end_fields["error"] = stream_error
    await websocket.send_text(
        WsStreamEnd(request_id=request_id, **end_fields).model_dump_json()
    )

    logger.info(
        "device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
        user_id, request_id, task_id, len(full_response), canvas_kind or "none",
    )
|
|
|
|
|
|
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
|
|
|
|
|
async def _handle_journey_start(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Kick off a guided setup journey and send its first reply.

    The websocket tool executor is installed for the duration of the call so
    the journey can invoke device tools. Any failure is reported to the
    client as a terminal (``done``) ``journey_reply`` frame.
    """
    set_client_executor(await _make_ws_executor(websocket, user_id))
    try:
        first_reply = await handle_journey_start(user_id, frame)
        await websocket.send_text(json.dumps(first_reply))
    except Exception as exc:
        logger.error(
            "device_ws: journey_start failed user=%s: %s", user_id, exc
        )
        failure_reply = {
            "type": "journey_reply",
            "session_id": frame.get("session_id", ""),
            "message": f"Failed to start journey: {exc}",
            "done": True,
            "prompt_template": None,
        }
        await websocket.send_text(json.dumps(failure_reply))
    finally:
        clear_client_executor()
|
|
|
|
|
|
async def _handle_journey_message(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Advance an ongoing setup journey with the user's next message.

    Mirrors ``_handle_journey_start``: the websocket tool executor is
    installed while ``handle_journey_message`` runs, and any failure is
    reported as a terminal (``done``) ``journey_reply`` frame tagged with the
    frame's ``session_id``.
    """
    set_client_executor(await _make_ws_executor(websocket, user_id))
    try:
        next_reply = await handle_journey_message(user_id, frame)
        await websocket.send_text(json.dumps(next_reply))
    except Exception as exc:
        sid = frame.get("session_id", "")
        logger.error(
            "device_ws: journey_message failed user=%s session=%s: %s",
            user_id, sid, exc,
        )
        failure_reply = {
            "type": "journey_reply",
            "session_id": sid,
            "message": f"Journey error: {exc}",
            "done": True,
            "prompt_template": None,
        }
        await websocket.send_text(json.dumps(failure_reply))
    finally:
        clear_client_executor()
|
|
|
|
|
|
# ── v7 Folder Index Handlers ──────────────────────────────────────────
|
|
|
|
|
|
async def _handle_index_session_start(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Register a new folder index session. No response sent — client is declaring intent.

    Accepts both camelCase and snake_case keys (``sessionId``/``session_id``,
    ``projectId``/``project_id``, ``totalFiles``/``total_files``). The frame
    is ignored, with a warning, in three cases:
    - no session id is given;
    - the file total is not an integer;
    - the session id is already registered to a different user.
    """
    session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
    project_id: str | None = frame.get("projectId") or frame.get("project_id")
    raw_total = frame.get("totalFiles") or frame.get("total_files") or 0

    if not session_id:
        logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
        return

    try:
        total: int = int(raw_total)
    except (TypeError, ValueError):
        logger.warning(
            "device_ws: index_session_start invalid totalFiles=%r user=%s session=%s",
            raw_total, user_id, session_id,
        )
        return

    # Session ids are client-supplied; never let one user overwrite (and so
    # take over) a session another user registered.
    existing = _index_sessions.get(session_id)
    if existing is not None and existing.get("user_id") != user_id:
        logger.warning(
            "device_ws: index_session_start session owned by another user user=%s session=%s",
            user_id, session_id,
        )
        return

    _index_sessions[session_id] = {
        "user_id": user_id,
        "project_id": project_id,
        "processed": 0,
        "total": total,
        "cancelled": False,
    }
    logger.info(
        "device_ws: index_session_start user=%s session=%s project=%s total=%d",
        user_id, session_id, project_id, total,
    )
|
|
|
|
|
|
async def _handle_index_session_cancel(
    websocket: WebSocket,
    frame: dict,
) -> None:
    """Cancel a folder index session and confirm it to the client.

    A known session is first flagged ``cancelled``, so an in-flight
    ``index_file_batch`` for it stops at its next file. Then an
    ``index_session_done`` frame with status ``cancelled`` is sent, and the
    session is dropped from the registry. The done frame is sent even for
    unknown session ids.

    NOTE(review): the session's owning user is not checked here, so any
    connection that knows a session id can cancel it.
    """
    session_id: str = frame.get("sessionId") or frame.get("session_id") or ""

    tracked = _index_sessions.get(session_id)
    if tracked:
        tracked["cancelled"] = True

    done_frame = {
        "type": WsFrameType.index_session_done,
        "sessionId": session_id,
        "status": "cancelled",
    }
    await websocket.send_text(json.dumps(done_frame))
    _index_sessions.pop(session_id, None)
    logger.info("device_ws: index_session_cancel session=%s", session_id)
|
|
|
|
|
|
async def _handle_index_file_batch(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Process a batch of files for an index session, streaming results back.

    For each file in the batch:
    - summarize it according to ``kind`` (image / pdf / docx; anything else
      as text);
    - record the tokens used against the tier's ``folder_monthly_tokens`` cap;
    - send an ``index_file_result`` frame.

    A summarize failure is reported for that file and does not stop the
    batch. Reaching the token cap ends the session with ``quota_exceeded``.
    After the batch an ``index_session_progress`` frame is sent, followed by
    ``index_session_done`` (``completed``) once ``processed`` reaches
    ``total``.

    Batches are dispatched as independent tasks, so several may run at once
    for the same session and share its state dict.
    """
    # Lazy imports to avoid heavy load at module startup.
    from app.core.folder_indexer import (  # noqa: PLC0415
        summarize_image,
        summarize_pdf,
        summarize_docx,
        summarize_text,
    )
    from app.billing.tier_manager import tier_manager  # noqa: PLC0415
    from app.billing.quota import add_token_usage  # noqa: PLC0415

    session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
    files: list[dict] = frame.get("files", [])

    session = _index_sessions.get(session_id)
    if not session or session.get("cancelled"):
        return
    # Session ids are client-supplied: never process (and bill) files against
    # a session that a different user registered.
    if session.get("user_id") != user_id:
        logger.warning(
            "device_ws: index_file_batch session not owned by user=%s session=%s",
            user_id, session_id,
        )
        return

    async with async_session() as db:
        tier = await tier_manager.get_tier(user_id, db)
        raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
        # A raw value of -1 is passed on as cap=None (presumably "no cap").
        cap: int | None = None if raw_cap == -1 else raw_cap

        for file_info in files:
            if session.get("cancelled"):
                return

            # Electron's toSnakeCase converts payload keys, so accept both forms.
            rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
            kind: str = file_info.get("kind") or "text"
            content: str = file_info.get("content") or ""
            ext: str = file_info.get("ext") or ""
            mime: str = file_info.get("mime") or "application/octet-stream"
            name: str = rel_path.split("/")[-1] or rel_path

            try:
                if kind == "image":
                    res = await summarize_image(image_b64=content, mime=mime)
                elif kind == "pdf":
                    res = await summarize_pdf(pdf_b64=content, name=name)
                elif kind == "docx":
                    res = await summarize_docx(docx_b64=content, name=name)
                else:
                    res = await summarize_text(content=content, ext=ext, name=name)
            except Exception as exc:
                logger.warning(
                    "device_ws: index_file_batch summarize failed session=%s path=%s: %s",
                    session_id, rel_path, exc,
                )
                await websocket.send_text(json.dumps({
                    "type": WsFrameType.index_file_result,
                    "sessionId": session_id,
                    "relPath": rel_path,
                    "summary": None,
                    "tokensUsed": 0,
                    "error": str(exc),
                }))
                session["processed"] += 1
                continue

            # Account for token usage and check cap.
            usage = await add_token_usage(
                user_id=user_id,
                feature="folder_index",
                tokens=res.tokens_used,
                db=db,
                cap=cap,
            )

            await websocket.send_text(json.dumps({
                "type": WsFrameType.index_file_result,
                "sessionId": session_id,
                "relPath": rel_path,
                "summary": res.summary,
                "tokensUsed": res.tokens_used,
            }))
            session["processed"] += 1

            if usage.exhausted:
                # Flag the shared session so any concurrently running batch
                # for it stops too, instead of spending more tokens.
                session["cancelled"] = True
                await websocket.send_text(json.dumps({
                    "type": WsFrameType.index_session_done,
                    "sessionId": session_id,
                    "status": "quota_exceeded",
                }))
                _index_sessions.pop(session_id, None)
                logger.info(
                    "device_ws: index_session quota_exceeded user=%s session=%s",
                    user_id, session_id,
                )
                return

    # After processing the batch, emit progress.
    processed = session["processed"]
    total = session["total"]
    await websocket.send_text(json.dumps({
        "type": WsFrameType.index_session_progress,
        "sessionId": session_id,
        "processed": processed,
        "total": total,
    }))

    # Unregister *before* awaiting the send, so a concurrent batch that also
    # sees processed >= total no longer finds the session and does not emit
    # a second index_session_done.
    if processed >= total and _index_sessions.get(session_id) is session:
        del _index_sessions[session_id]
        await websocket.send_text(json.dumps({
            "type": WsFrameType.index_session_done,
            "sessionId": session_id,
            "status": "completed",
        }))
        logger.info(
            "device_ws: index_session_done completed user=%s session=%s processed=%d",
            user_id, session_id, processed,
        )
|
|
|
|
|
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
|
|
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
    """Periodically send a ``ping`` frame to keep the connection alive.

    Sleeps ``_HEARTBEAT_INTERVAL`` seconds between pings. The loop never
    returns on its own: it runs until cancelled, or until a send raises, in
    which case the exception propagates to the caller.
    """
    ping_frame = json.dumps({"type": "ping"})
    while True:
        await asyncio.sleep(_HEARTBEAT_INTERVAL)
        await websocket.send_text(ping_frame)
|
|
|
|
|
|
# ── Disconnect cleanup ────────────────────────────────────────────────
|
|
|
|
async def _mark_runs_disconnected(user_id: str) -> None:
    """Flag every still-running AgentRunLog row of *user_id* as failed.

    Issues a single UPDATE that sets ``status`` to ``"error"`` and
    ``errors`` to ``["device disconnected"]``, then commits. Any failure is
    logged and swallowed, so disconnect cleanup never raises.
    """
    try:
        stale_runs = (
            update(AgentRunLog)
            .where(AgentRunLog.user_id == user_id, AgentRunLog.status == "running")
            .values(status="error", errors=["device disconnected"])
        )
        async with async_session() as db:
            await db.execute(stale_runs)
            await db.commit()
    except Exception as exc:
        logger.error(
            "device_ws: failed to mark runs as disconnected for user=%s: %s",
            user_id,
            exc,
        )
|