"""WebSocket context for Batch Agent Service — Redis-based tool call round-trip. Same pattern as services/chat/app/ws_context.py: publishes tool_call frames to Redis ws:out:{user_id} and awaits BRPOP on tool:result:{call_id}. Additionally provides set_client_executor / clear_client_executor stubs for backward compatibility with the agent_runner code (which originally used a DeviceConnectionManager callback). In the microservice world these are no-ops — execute_on_client() always uses the Redis path. """ from __future__ import annotations import json import logging from contextvars import ContextVar from typing import Any, Callable, Coroutine from uuid import uuid4 from shared.redis import redis_client, tool_result_key, ws_out_channel logger = logging.getLogger(__name__) _TOOL_CALL_TIMEOUT = 30 # seconds — BRPOP timeout # Per-request user_id context var (set before agent run) _current_user_id: ContextVar[str | None] = ContextVar("_current_user_id", default=None) # Optional collector for debug / logging _tool_result_collector: ContextVar[list[dict] | None] = ContextVar( "_tool_result_collector", default=None ) def set_current_user(user_id: str) -> None: _current_user_id.set(user_id) def clear_current_user() -> None: _current_user_id.set(None) def set_tool_result_collector(lst: list[dict]) -> None: _tool_result_collector.set(lst) def clear_tool_result_collector() -> None: _tool_result_collector.set(None) # ── Compatibility shims ────────────────────────────────────────────────── # agent_runner.py originally called set_client_executor / clear_client_executor # with a DeviceConnectionManager callback. In the microservice world the # Redis-based execute_on_client replaces this, so these are no-ops that # keep the agent_runner code unchanged. 
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]] | None) -> None:
    """No-op — kept for agent_runner compatibility.

    *fn* (formerly a DeviceConnectionManager callback) is ignored;
    execute_on_client() always performs the Redis round-trip.
    """


def clear_client_executor() -> None:
    """No-op — kept for agent_runner compatibility."""


async def execute_on_client(
    action: str,
    table: str | None = None,
    data: dict[str, Any] | None = None,
    filters: dict[str, Any] | None = None,
    vector: list[float] | None = None,
    limit: int | None = None,
) -> dict[str, Any]:
    """Send a tool_call to Electron via Redis and await the result.

    1. Build the tool_call payload.
    2. Publish to ws:out:{user_id} (the WS Gateway forwards it to Electron).
    3. BRPOP on tool:result:{call_id} (the WS Gateway pushes when Electron replies).
    4. Return the decoded result.

    Args:
        action: Tool action name, sent as ``"action"``.
        table: Optional target table; omitted from the payload when None.
        data: Optional data dict; omitted from the payload when None.
        filters: Optional filters. Entries whose value is None are dropped.
            The key is omitted entirely when *filters* itself is None.
        vector: Optional embedding vector; omitted from the payload when None.
        limit: Optional result limit; omitted from the payload when None.

    Returns:
        The JSON-decoded value that was pushed onto tool:result:{call_id}.

    Raises:
        RuntimeError: If no user_id has been set via set_current_user(), or
            if no result arrives within _TOOL_CALL_TIMEOUT seconds.
    """
    user_id = _current_user_id.get()
    if not user_id:
        raise RuntimeError(
            "execute_on_client() called without a user_id — "
            "set_current_user() must be called first."
        )

    call_id = str(uuid4())
    payload: dict[str, Any] = {
        "type": "tool_call",
        "id": call_id,
        "action": action,
    }
    if table is not None:
        payload["table"] = table
    if data is not None:
        payload["data"] = data
    if filters is not None:
        # Drop None-valued filter entries before sending.
        payload["filters"] = {k: v for k, v in filters.items() if v is not None}
    if vector is not None:
        payload["vector"] = vector
    if limit is not None:
        payload["limit"] = limit

    # Publish tool_call to WS Gateway → Electron
    channel = ws_out_channel(user_id)
    receivers = await redis_client.publish(channel, json.dumps(payload))
    if receivers == 0:
        # PUBLISH returns the number of subscribers that received the message.
        # 0 means nobody was listening on this user's channel. This is logged
        # but not treated as fatal: on Redis Cluster the count only covers the
        # node that handled the PUBLISH.
        logger.warning(
            "tool_call %s (action=%s) published to %s with no subscribers; "
            "awaiting result anyway",
            call_id,
            action,
            channel,
        )

    # Wait for Electron's tool_result
    result_key = tool_result_key(call_id)
    response = await redis_client.brpop(result_key, timeout=_TOOL_CALL_TIMEOUT)
    if response is None:
        logger.warning(
            "tool_call %s (action=%s, user=%s) timed out after %ss",
            call_id,
            action,
            user_id,
            _TOOL_CALL_TIMEOUT,
        )
        raise RuntimeError(
            f"Tool call {call_id} timed out after {_TOOL_CALL_TIMEOUT}s — "
            f"device may be offline or unresponsive."
        )

    # BRPOP yields a (key, value) pair; only the value is needed.
    _, raw = response
    result = json.loads(raw)

    # Collect for debug if requested
    collector = _tool_result_collector.get(None)
    if collector is not None:
        collector.append({
            "action": action,
            "table": table,
            "data": result,
        })

    return result