116 lines
3.8 KiB
Python
116 lines
3.8 KiB
Python
"""WebSocket client executor context.

Holds a per-session async callback that tools call to execute CRUD and
vector operations on the Electron client's local SQLite / LanceDB databases.
The callback sends a `tool_call` WS frame and awaits the `tool_result`.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from contextvars import ContextVar
|
|
from typing import Any, Callable, Coroutine
|
|
from uuid import uuid4
|
|
|
|
_SNAKE_TO_CAMEL_RE = re.compile(r"_([a-z])")


def _key_to_camel(key: str) -> str:
    """Convert one snake_case *key* to camelCase.

    Every underscore followed by a lowercase ASCII letter is removed and the
    letter is upper-cased: ``"created_at"`` -> ``"createdAt"``. Underscores
    followed by anything else are kept (``"foo_1"`` stays ``"foo_1"``).

    NOTE: a leading underscore is consumed too, so ``"_id"`` becomes ``"Id"``.
    """

    def _capitalize(match: re.Match[str]) -> str:
        # group(1) is the single lowercase letter that followed the "_".
        return match.group(1).upper()

    return _SNAKE_TO_CAMEL_RE.sub(_capitalize, key)


def _keys_to_camel(obj: Any) -> Any:
    """Return a copy of *obj* whose dict keys are camelCased at every depth.

    Dicts and lists are rebuilt recursively; any other value (including
    tuples) is returned unchanged. Mirrors the JS-side ``toCamelCase`` applied
    to incoming WS frames in ``adiuvAI/src/main/api/backend-client.ts``: the
    Electron executor wraps tool_result payloads in ``toSnakeCase`` before
    sending, and this restores the camelCase schema property names the tool
    code expects to read.

    Keys are assumed to be strings; a non-str key makes ``re.sub`` raise
    ``TypeError``. If two keys map to the same camelCase name, the one that
    comes later in iteration order wins.
    """
    if isinstance(obj, list):
        return [_keys_to_camel(item) for item in obj]
    if not isinstance(obj, dict):
        return obj

    converted: dict[Any, Any] = {}
    for key, value in obj.items():
        # Convert the key before recursing into the value, matching the
        # evaluation order of a dict comprehension.
        new_key = _key_to_camel(key)
        converted[new_key] = _keys_to_camel(value)
    return converted
|
|
|
|
# Holds the execute callback for the current WS session.
# Set by the chat WS handler before the orchestrator runs; cleared after.
# ``clear_client_executor`` stores ``None`` here rather than unsetting the
# variable, and ``execute_on_client`` treats ``None`` as "no session", so
# ``None`` is part of the declared value type. There is intentionally no
# default: ``_client_executor.get()`` called without an argument still raises
# ``LookupError`` when nothing was ever set in this context.
_client_executor: ContextVar[
    Callable[[dict], Coroutine[Any, Any, dict]] | None
] = ContextVar("_client_executor")

# Optional collector that records every execute_on_client call as a
# {"action", "table", "data"} dict, where "data" is the camelCased result.
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results;
# ``None`` (the default) disables collection.
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
    "_tool_result_collector", default=None
)
|
|
|
|
|
|
def set_tool_result_collector(lst: list[dict]) -> None:
    """Start collecting execute_on_client results into *lst* for this context.

    The list object itself is stored (not a copy), so the caller sees every
    entry that later execute_on_client calls in the same context append to it.
    """
    collector_var = _tool_result_collector
    collector_var.set(lst)
|
|
|
|
|
|
def clear_tool_result_collector() -> None:
    """Stop collecting execute_on_client results in the current context.

    Stores ``None`` in the collector variable. The list registered earlier is
    not touched and keeps the entries it already received.
    """
    no_collector = None
    _tool_result_collector.set(no_collector)
|
|
|
|
|
|
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
    """Install *fn* as the client executor for the current async context.

    ``execute_on_client`` fetches the callback from this context variable, so
    the binding is only visible to code running in the current context.
    """
    # The Token returned by ContextVar.set is not kept; unbinding is done by
    # storing None (see clear_client_executor) rather than by reset(token).
    _unused_token = _client_executor.set(fn)
|
|
|
|
|
|
def clear_client_executor() -> None:
    """Unbind the client executor for the current async context.

    Stores ``None`` rather than resetting the variable, so a later
    ``execute_on_client`` call in this context raises ``RuntimeError``
    instead of reaching a stale WebSocket callback.

    ``ContextVar.set`` accepts any value, so this cannot fail. The former
    ``try/except Exception: pass`` guard was dead code that would only have
    hidden real errors, and it made this function inconsistent with
    ``clear_tool_result_collector``.
    """
    _client_executor.set(None)  # type: ignore[arg-type]
|
|
|
|
|
|
async def execute_on_client(
    action: str,
    table: str | None = None,
    data: dict[str, Any] | None = None,
    filters: dict[str, Any] | None = None,
    vector: list[float] | None = None,
    limit: int | None = None,
) -> dict[str, Any]:
    """Run one CRUD/vector operation on the Electron client via the WS executor.

    A ``tool_call`` payload is assembled from a fresh UUID4 string (``"id"``),
    *action*, and whichever of *table*, *data*, *filters*, *vector* and *limit*
    are not ``None``. Entries of *filters* whose value is ``None`` are left
    out, but an emptied filter dict is still sent. The payload is passed to the
    executor bound for the current context. The awaited ``tool_result`` has its
    dict keys converted from snake_case to camelCase (recursively) before it is
    returned.

    If a result collector is registered for this context, an entry
    ``{"action": action, "table": table, "data": <converted result>}`` is
    appended to it. The entry references the same object that is returned,
    not a copy.

    Raises:
        RuntimeError: no executor is bound in this context, i.e. the call
            happened outside a WebSocket session.
    """
    executor = _client_executor.get(None)
    if executor is None:
        raise RuntimeError(
            "execute_on_client() called outside a WebSocket session — "
            "no client executor is set."
        )

    payload: dict[str, Any] = {"id": str(uuid4()), "action": action}

    cleaned_filters = (
        None
        if filters is None
        else {key: value for key, value in filters.items() if value is not None}
    )
    # Tuple order fixes the key order of the outgoing payload.
    optional_fields = (
        ("table", table),
        ("data", data),
        ("filters", cleaned_filters),
        ("vector", vector),
        ("limit", limit),
    )
    payload.update(
        (field, value) for field, value in optional_fields if value is not None
    )

    reply = _keys_to_camel(await executor(payload))

    collector = _tool_result_collector.get(None)
    if collector is not None:
        collector.append({"action": action, "table": table, "data": reply})
    return reply
|