Compare commits
5 Commits
main
...
c6e1e4e7fd
| Author | SHA1 | Date | |
|---|---|---|---|
| c6e1e4e7fd | |||
| cc603aba06 | |||
| 6d9a16e513 | |||
| 27c087d5d8 | |||
|
|
4d7fd519c5 |
243
AI_REFACTOR_PLAN.md
Normal file
243
AI_REFACTOR_PLAN.md
Normal file
@@ -0,0 +1,243 @@
|
||||
# AI Refactor Plan — Adiuva Backend
|
||||
|
||||
> **Objective:** Transform backend tools from JSON-action-descriptor-returning functions into real bidirectional executors. Each tool sends structured CRUD operations to the Electron client via WebSocket, receives real data back, and returns meaningful results to the LLM. The LLM reasons about actual user data instead of serialized action payloads.
|
||||
>
|
||||
> **Electron app:** Lives at `../adiuva/`. See `../adiuva/AI_REFACTOR_PLAN.md`.
|
||||
>
|
||||
> **Protocol:** Execute steps sequentially. Each step is atomic and committable. Mark `[x]` when done.
|
||||
|
||||
---
|
||||
|
||||
## Architecture — Before vs After
|
||||
|
||||
### Before (current)
|
||||
```
|
||||
LLM calls list_tasks(status="todo")
|
||||
→ tool returns: '{"action":"list","table":"tasks","filters":{"status":"todo"}}'
|
||||
→ _tool_loop feeds that JSON string as ToolMessage to LLM
|
||||
→ LLM sees a descriptor, NOT real data — cannot reason about tasks
|
||||
→ Final response: generic "Here are your tasks" (no actual task data)
|
||||
→ Action descriptors sent in final WS frame for Electron to execute post-response
|
||||
```
|
||||
|
||||
### After (target)
|
||||
```
|
||||
LLM calls list_tasks(status="todo")
|
||||
→ tool calls execute_on_client(action="select", table="tasks", filters={status:"todo"})
|
||||
→ WS frame sent to Electron: {type:"tool_call", id:"abc", action:"select", table:"tasks", filters:{status:"todo"}}
|
||||
→ Electron runs: db.select().from(tasks).where(eq(tasks.status, "todo")).all()
|
||||
→ WS frame back: {type:"tool_result", id:"abc", rows:[{id:"1",title:"Buy milk",...}, ...]}
|
||||
→ tool returns: "Found 3 tasks: 1. Buy milk (high, due tomorrow) 2. ..."
|
||||
→ _tool_loop feeds that as ToolMessage to LLM
|
||||
→ LLM sees REAL data — can reason, count, compare, summarize
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## WS Protocol — Typed Frames
|
||||
|
||||
| Direction | `type` | Payload |
|
||||
|---|---|---|
|
||||
| Client → Server | `chat_request` | `{ message: str, context: ChatContext }` |
|
||||
| Server → Client | `text_chunk` | `{ text: str }` |
|
||||
| Server → Client | `tool_call` | `{ id: str, action: str, table?: str, data?: dict, filters?: dict, vector?: list[float], limit?: int }` |
|
||||
| Client → Server | `tool_result` | `{ id: str, row?: dict, rows?: list[dict], results?: list[dict], deleted?: bool, ok?: bool, error?: str }` |
|
||||
| Server → Client | `final` | `{ response: str }` |
|
||||
| Server → Client | `ping` | `{}` |
|
||||
|
||||
**Actions:**
|
||||
|
||||
| `action` | What Electron does (Drizzle) | `tool_result` shape |
|
||||
|---|---|---|
|
||||
| `select` | `db.select().from(table).where(filters)` | `{ rows: [...] }` |
|
||||
| `get` | `db.select().from(table).where(id=...).get()` | `{ row: {...} or null }` |
|
||||
| `insert` | `db.insert(table).values({id: uuid(), ...data}).returning().get()` | `{ row: {...} }` |
|
||||
| `update` | `db.update(table).set(updates).where(id=...).returning().get()` | `{ row: {...} }` |
|
||||
| `delete` | `db.delete(table).where(id=...).run()` | `{ deleted: true }` |
|
||||
| `vector_upsert` | LanceDB upsert with pre-computed vector | `{ ok: true }` |
|
||||
| `vector_search` | LanceDB search by vector | `{ results: [{id, content, score}...] }` |
|
||||
|
||||
**Electron generates IDs + timestamps.** Backend tools never send `id` or `createdAt` in `insert` data — Electron adds `id: uuid()`, `createdAt: Date.now()`, `updatedAt: Date.now()`.
|
||||
|
||||
---
|
||||
|
||||
## SQLite Schema Reference (Electron's local database)
|
||||
|
||||
Tools must use **camelCase** field names (Drizzle maps them to snake_case internally):
|
||||
|
||||
| Table | Columns |
|
||||
|---|---|
|
||||
| `tasks` | id, projectId, title, description, status (todo\|in_progress\|done), priority (high\|medium\|low), assignee (JSON array string), dueDate (ms), isAiSuggested (0\|1), isApproved (0\|1), createdAt (ms) |
|
||||
| `projects` | id, clientId, name, status (active\|archived), aiSummary, createdAt (ms) |
|
||||
| `checkpoints` | id, projectId (required), title, date (ms), isAiSuggested (0\|1), isApproved (0\|1), createdAt (ms) |
|
||||
| `notes` | id, projectId, title, content (markdown), createdAt (ms), updatedAt (ms) |
|
||||
| `taskComments` | id, taskId, author, content, createdAt (ms) |
|
||||
| `clients` | id, parentId, name, industry, createdAt (ms) |
|
||||
|
||||
---
|
||||
|
||||
## Phase B — Backend Changes
|
||||
|
||||
### Step B.1 — WS context + frame types
|
||||
- [x] Create `app/core/ws_context.py` (~25 lines):
|
||||
- `_client_executor: ContextVar[Callable]` — holds the async callback for the current WS session
|
||||
- `async def execute_on_client(action, table=None, data=None, filters=None, vector=None, limit=None) -> dict`:
|
||||
- Reads callback from ContextVar
|
||||
- Builds `tool_call` payload: `{id: str(uuid4()), action, table, data, filters, vector, limit}` (omits None fields)
|
||||
- Calls `await callback(payload)` — which sends the WS frame and waits for `tool_result`
|
||||
- Returns the result dict
|
||||
- `def set_client_executor(fn)` / `def clear_client_executor()` — ContextVar management
|
||||
- [x] Add to `app/schemas.py`:
|
||||
- `WsFrameType(str, Enum)`: `chat_request`, `text_chunk`, `tool_call`, `tool_result`, `final`, `ping`
|
||||
- `WsToolCall(BaseModel)`: `type`, `id`, `action`, `table?`, `data?`, `filters?`, `vector?`, `limit?`
|
||||
- `WsToolResult(BaseModel)`: `type`, `id`, `row?`, `rows?`, `results?`, `deleted?`, `ok?`, `error?`
|
||||
- `WsTextChunk(BaseModel)`: `type`, `text`
|
||||
- `WsFinal(BaseModel)`: `type`, `response`
|
||||
- **Files:** `app/core/ws_context.py`, `app/schemas.py`
|
||||
- **Outcome:** Any tool can `await execute_on_client(...)` to query/mutate the user's local DB.
|
||||
|
||||
### Step B.2 — Rewrite all 23 tools to use `execute_on_client()`
|
||||
- [x] Each tool: same `@tool` decorator, same parameters, same docstring. Replace `return json.dumps({...})` body with:
|
||||
1. Call `result = await execute_on_client(action=..., table=..., data/filters=...)`
|
||||
2. Return human-readable string with confirmation + key data from `result`
|
||||
|
||||
- [x] **`app/agents/task_agent.py` (8 tools):**
|
||||
- `list_tasks(project_id, status, search, order_by)`:
|
||||
```python
|
||||
result = await execute_on_client(action="select", table="tasks", filters={
|
||||
"projectId": project_id or None,
|
||||
"status": status or None,
|
||||
"search": search or None,
|
||||
"orderBy": order_by or None,
|
||||
})
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks found matching the given filters."
|
||||
lines = [f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||
```
|
||||
- `create_task(title, ...)`:
|
||||
```python
|
||||
result = await execute_on_client(action="insert", table="tasks", data={
|
||||
"title": title, "description": description or None, "status": status,
|
||||
"priority": priority, "assignee": assignees, "dueDate": due_date or None,
|
||||
"projectId": project_id or None, "isAiSuggested": is_ai_suggested, "isApproved": is_approved,
|
||||
})
|
||||
row = result["row"]
|
||||
return f"Task created: '{row['title']}' (id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
||||
```
|
||||
- `update_task(task_id, ...)`: build updates dict (same logic as now) → `execute_on_client(action="update", table="tasks", data={"id": task_id, "updates": updates})` → return "Task updated: {title}"
|
||||
- `delete_task(task_id)`: `execute_on_client(action="delete", table="tasks", data={"id": task_id})` → return "Task deleted"
|
||||
- `list_tasks_due_today()`: calculate today's start/end ms → `execute_on_client(action="select", table="tasks", filters={"dueDateFrom": start, "dueDateTo": end})` → format + return
|
||||
- `list_task_comments(task_id)`: `execute_on_client(action="select", table="taskComments", filters={"taskId": task_id})` → format + return
|
||||
- `add_task_comment(task_id, author, content)`: `execute_on_client(action="insert", table="taskComments", data={...})` → return confirmation
|
||||
- `delete_task_comment(comment_id)`: `execute_on_client(action="delete", table="taskComments", data={"id": comment_id})` → return confirmation
|
||||
|
||||
- [x] **`app/agents/project_agent.py` (6 tools):**
|
||||
- `list_projects(client_id, include_archived)`: `execute_on_client(action="select", table="projects", filters={clientId, includeArchived})` → format + return
|
||||
- `list_all_projects()`: `execute_on_client(action="select", table="projects")` → format + return
|
||||
- `get_project(project_id)`: `execute_on_client(action="get", table="projects", data={"id": project_id})` → return project details or "not found"
|
||||
- `create_project(name, client_id)`: `execute_on_client(action="insert", table="projects", data={name, clientId})` → return confirmation + id
|
||||
- `update_project(project_id, ...)`: build updates → `execute_on_client(action="update", ...)` → return confirmation
|
||||
- `delete_project(project_id)`: `execute_on_client(action="delete", ...)` → return confirmation
|
||||
|
||||
- [x] **`app/agents/checkpoint_agent.py` (4 tools):**
|
||||
- `list_checkpoints(project_id)`: `execute_on_client(action="select", table="checkpoints", filters={projectId})` → format + return
|
||||
- `create_checkpoint(project_id, title, date, ...)`: `execute_on_client(action="insert", table="checkpoints", data={...})` → return confirmation + id
|
||||
- `update_checkpoint(checkpoint_id, ...)`: build updates → `execute_on_client(action="update", ...)` → return confirmation
|
||||
- `delete_checkpoint(checkpoint_id)`: `execute_on_client(action="delete", ...)` → return confirmation
|
||||
|
||||
- [x] **`app/agents/note_agent.py` (5 tools):**
|
||||
- `list_notes(project_id)`: `execute_on_client(action="select", table="notes", filters={projectId})` → format + return
|
||||
- `get_note(note_id)`: `execute_on_client(action="get", table="notes", data={"id": note_id})` → return full content or "not found"
|
||||
- `create_note(title, content, project_id)`: `execute_on_client(action="insert", table="notes", data={...})` → then `execute_on_client(action="vector_upsert", data={id, projectId, content}, vector=await embed(content))` → return confirmation
|
||||
- `update_note(note_id, ...)`: build updates → `execute_on_client(action="update", ...)` → then vector_upsert for updated content → return confirmation
|
||||
- `delete_note(note_id)`: `execute_on_client(action="delete", ...)` → return confirmation
|
||||
|
||||
- **Files:** `app/agents/task_agent.py`, `app/agents/project_agent.py`, `app/agents/checkpoint_agent.py`, `app/agents/note_agent.py`
|
||||
- **Outcome:** All 23 tools query real user data via WS. LLM sees actual rows, not action descriptors.
|
||||
|
||||
### Step B.3 — Bidirectional WebSocket handler
|
||||
- [x] Refactor `app/api/routes/chat.py` WS endpoint:
|
||||
- After auth + accept + receive `chat_request`:
|
||||
1. Create `execute_on_client` callback closure capturing the websocket:
|
||||
```python
|
||||
pending_calls: dict[str, asyncio.Future] = {}
|
||||
|
||||
async def on_client_result(frame: dict):
|
||||
"""Called when a tool_result frame arrives from Electron."""
|
||||
fut = pending_calls.pop(frame["id"], None)
|
||||
if fut and not fut.done():
|
||||
fut.set_result(frame)
|
||||
|
||||
async def execute_callback(payload: dict) -> dict:
|
||||
"""Send tool_call to Electron, wait for tool_result."""
|
||||
call_id = payload["id"]
|
||||
fut = asyncio.get_event_loop().create_future()
|
||||
pending_calls[call_id] = fut
|
||||
await websocket.send_text(json.dumps({"type": "tool_call", **payload}))
|
||||
return await asyncio.wait_for(fut, timeout=30.0)
|
||||
```
|
||||
2. Set `client_executor` ContextVar with `execute_callback`
|
||||
3. Run orchestrator in a task — it calls agents, agents call tools, tools call `execute_on_client()` which goes through the callback
|
||||
4. In parallel, run a message receive loop that dispatches incoming frames:
|
||||
- `tool_result` → `on_client_result(frame)`
|
||||
- `ping` → ignore
|
||||
5. Orchestrator yields `text_chunk` frames → send to client
|
||||
6. Send `final` frame when done
|
||||
7. Clear ContextVar
|
||||
- Keep heartbeat ping every 30s
|
||||
- 30s timeout on `tool_result` — if Electron doesn't respond, future raises `TimeoutError`, tool returns error string to LLM
|
||||
- **Files:** `app/api/routes/chat.py`
|
||||
- **Outcome:** Full bidirectional WS. Tool calls and text streaming happen concurrently on the same connection.
|
||||
|
||||
### Step B.4 — `_tool_loop` — no changes needed
|
||||
- [x] Verify `app/core/agent_registry.py` works unchanged:
|
||||
- `_tool_loop` calls `tool_fn.ainvoke(args)` → tool awaits `execute_on_client()` (WS round-trip) → returns string → `ToolMessage(content=string)` → LLM sees real data
|
||||
- The async WS round-trip happens inside each tool. `_tool_loop` just sees an awaited tool returning a string — same as before, different content.
|
||||
- **No code changes.** Just verify + add a log line for tool execution times if desired.
|
||||
|
||||
### Step B.5 — Orchestrator cleanup
|
||||
- [x] Update `app/core/orchestrator.py`:
|
||||
- `orchestrate_stream()`: remove `"actions": []` from final frame. Final becomes: `{"done": true, "response": "..."}`
|
||||
- No other changes — `classify_intent` → `call_agent` → chunk response → final frame
|
||||
- **Files:** `app/core/orchestrator.py`
|
||||
- **Outcome:** Clean final frame. No more action descriptors in the protocol.
|
||||
|
||||
### Step B.6 — Add `/vectors/embed` endpoint
|
||||
- [x] Add to `app/api/routes/vectors.py`:
|
||||
- `POST /api/v1/storage/vectors/embed`:
|
||||
- Request: `{ text: str }`
|
||||
- Response: `{ vector: list[float] }` (1536-dim from `text-embedding-3-small`)
|
||||
- Auth required (JWT)
|
||||
- Used by:
|
||||
- Backend tools: `note_agent` calls this before `vector_upsert`
|
||||
- Electron: `vectordb.ts` calls this for note embedding on create/update
|
||||
- **Files:** `app/api/routes/vectors.py`
|
||||
- **Outcome:** Single embedding endpoint. Both backend tools and Electron can generate vectors.
|
||||
|
||||
---
|
||||
|
||||
## Verification
|
||||
|
||||
| What to test | How |
|
||||
|---|---|
|
||||
| **Read flow** | "List my tasks" → `list_tasks` → `tool_call{select, tasks}` → Electron returns rows → LLM describes real tasks |
|
||||
| **Write flow** | "Create a task called Buy milk" → `create_task` → `tool_call{insert, tasks, data:{title:"Buy milk"}}` → Electron inserts + returns row → tool confirms with id |
|
||||
| **Multi-tool** | "How many todo tasks do I have?" → `list_tasks(status=todo)` → LLM counts actual rows → "You have 3 todo tasks" |
|
||||
| **Vector search** | "Find notes about deployment" → tool embeds → `tool_call{vector_search, vector:[...]}` → Electron searches LanceDB → returns matching notes |
|
||||
| **Vector upsert** | "Create a note about..." → insert note → vector_upsert with embedding → both SQLite + LanceDB updated |
|
||||
| **Tool timeout** | Disconnect Electron mid-conversation → 30s timeout → tool returns error → LLM handles gracefully |
|
||||
| **Concurrent calls** | Agent calls 2 tools in sequence → each does WS round-trip → both succeed → LLM sees both results |
|
||||
| **_tool_loop max iter** | Verify 5-iteration limit still works → after 5 tool calls, LLM forced to answer without tools |
|
||||
|
||||
---
|
||||
|
||||
## Execution Notes
|
||||
|
||||
- **Phase 1 is the critical path.** Auth + backend client + drizzle executor + orchestrator refactor must land first.
|
||||
- **Steps 1.1–1.4 are additive** — existing app keeps working until Step 1.5 swaps the orchestrator.
|
||||
- **Step 2.1 is the point of no return** — after removing LangChain, there's no local AI fallback.
|
||||
- **Phase B (backend changes) must land before Phase 1.3–1.5** — Electron needs the bidirectional WS to talk to.
|
||||
- **Phase 3 and Phase 4 are independent** — can be parallelized after Phase 2.
|
||||
- **One step at a time.** Mark `[x]` and commit with `step N.N complete: <outcome>`.
|
||||
@@ -21,18 +21,25 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── Enum types ────────────────────────────────────────────────────────
|
||||
billing_tier = postgresql.ENUM(
|
||||
"free", "pro", "power", "team", name="billing_tier", create_type=False
|
||||
)
|
||||
plugin_status = postgresql.ENUM(
|
||||
"pending_review", "approved", "rejected", name="plugin_status", create_type=False
|
||||
)
|
||||
review_decision = postgresql.ENUM(
|
||||
"approved", "rejected", name="review_decision", create_type=False
|
||||
)
|
||||
for enum in (billing_tier, plugin_status, review_decision):
|
||||
enum.create(op.get_bind(), checkfirst=True)
|
||||
# ── Enum types — idempotent creation via exception handling ───────────
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE billing_tier AS ENUM ('free', 'pro', 'power', 'team');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE plugin_status AS ENUM ('pending_review', 'approved', 'rejected');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE review_decision AS ENUM ('approved', 'rejected');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
|
||||
# ── users ─────────────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
@@ -40,7 +47,7 @@ def upgrade() -> None:
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("email", sa.String(255), nullable=False),
|
||||
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
|
||||
sa.Column("tier", postgresql.ENUM("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
|
||||
sa.Column("stripe_customer_id", sa.String(255), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
@@ -70,7 +77,7 @@ def upgrade() -> None:
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("stripe_subscription_id", sa.String(255), nullable=True),
|
||||
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
|
||||
sa.Column("tier", postgresql.ENUM("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
|
||||
sa.Column("status", sa.String(50), nullable=False, server_default="free"),
|
||||
sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
@@ -125,7 +132,7 @@ def upgrade() -> None:
|
||||
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
||||
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
||||
sa.Column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
|
||||
sa.Column("status", postgresql.ENUM("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
|
||||
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
||||
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
||||
@@ -157,7 +164,7 @@ def upgrade() -> None:
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||
sa.Column("decision", sa.Enum("approved", "rejected", name="review_decision", create_type=False), nullable=False),
|
||||
sa.Column("decision", postgresql.ENUM("approved", "rejected", name="review_decision", create_type=False), nullable=False),
|
||||
sa.Column("notes", sa.Text, nullable=True),
|
||||
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
@@ -10,6 +9,7 @@ from langchain_core.tools import tool
|
||||
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
from app.core.llm import get_llm
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_SYSTEM_PROMPT = (
|
||||
"You are a project checkpoint assistant. Checkpoints are milestone dates that\n"
|
||||
@@ -28,11 +28,16 @@ _SYSTEM_PROMPT = (
|
||||
@tool
|
||||
async def list_checkpoints(project_id: str = "") -> str:
|
||||
"""List checkpoints. Provide project_id to scope to a specific project."""
|
||||
return json.dumps({
|
||||
"action": "list",
|
||||
"table": "checkpoints",
|
||||
"filters": {"projectId": project_id or None},
|
||||
})
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="checkpoints",
|
||||
filters={"projectId": project_id or None},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No checkpoints found."
|
||||
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} checkpoint(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
@@ -50,17 +55,19 @@ async def create_checkpoint(
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
is_approved: 0 until the user confirms
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "create_record",
|
||||
"table": "checkpoints",
|
||||
"data": {
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="checkpoints",
|
||||
data={
|
||||
"projectId": project_id,
|
||||
"title": title,
|
||||
"date": date,
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
"isApproved": is_approved,
|
||||
},
|
||||
})
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Checkpoint created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
|
||||
|
||||
|
||||
@tool
|
||||
@@ -82,21 +89,20 @@ async def update_checkpoint(
|
||||
updates["date"] = date
|
||||
if is_approved != -1:
|
||||
updates["isApproved"] = is_approved
|
||||
return json.dumps({
|
||||
"action": "update_record",
|
||||
"table": "checkpoints",
|
||||
"data": {"id": checkpoint_id, "updates": updates},
|
||||
})
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="checkpoints",
|
||||
data={"id": checkpoint_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Checkpoint updated: '{row['title']}' (id: {row['id']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_checkpoint(checkpoint_id: str) -> str:
|
||||
"""Delete a checkpoint permanently by its UUID."""
|
||||
return json.dumps({
|
||||
"action": "delete_record",
|
||||
"table": "checkpoints",
|
||||
"data": {"id": checkpoint_id},
|
||||
})
|
||||
await execute_on_client(action="delete", table="checkpoints", data={"id": checkpoint_id})
|
||||
return f"Checkpoint {checkpoint_id} deleted."
|
||||
|
||||
|
||||
@registry.register
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
from app.core.llm import get_llm
|
||||
from app.core.llm import embed, get_llm
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_SYSTEM_PROMPT = (
|
||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||
@@ -29,21 +29,26 @@ _SYSTEM_PROMPT = (
|
||||
@tool
|
||||
async def list_notes(project_id: str = "") -> str:
|
||||
"""List notes, optionally scoped to a project by project_id."""
|
||||
return json.dumps({
|
||||
"action": "list",
|
||||
"table": "notes",
|
||||
"filters": {"projectId": project_id or None},
|
||||
})
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="notes",
|
||||
filters={"projectId": project_id or None},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No notes found."
|
||||
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_note(note_id: str) -> str:
|
||||
"""Fetch a single note by its UUID to read its full Markdown content."""
|
||||
return json.dumps({
|
||||
"action": "get",
|
||||
"table": "notes",
|
||||
"data": {"id": note_id},
|
||||
})
|
||||
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
|
||||
row = result.get("row")
|
||||
if not row:
|
||||
return f"Note {note_id} not found."
|
||||
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
|
||||
|
||||
|
||||
@tool
|
||||
@@ -57,15 +62,24 @@ async def create_note(
|
||||
content: Markdown body text (required)
|
||||
project_id: optional UUID linking this note to a project
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "create_record",
|
||||
"table": "notes",
|
||||
"data": {
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="notes",
|
||||
data={
|
||||
"title": title,
|
||||
"content": content,
|
||||
"projectId": project_id or None,
|
||||
},
|
||||
})
|
||||
)
|
||||
row = result["row"]
|
||||
# Index the note content in the vector store.
|
||||
vector = await embed(content)
|
||||
await execute_on_client(
|
||||
action="vector_upsert",
|
||||
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
|
||||
vector=vector,
|
||||
)
|
||||
return f"Note created: '{row['title']}' (id: {row['id']})."
|
||||
|
||||
|
||||
@tool
|
||||
@@ -83,21 +97,28 @@ async def update_note(
|
||||
updates["title"] = title
|
||||
if content:
|
||||
updates["content"] = content
|
||||
return json.dumps({
|
||||
"action": "update_record",
|
||||
"table": "notes",
|
||||
"data": {"id": note_id, "updates": updates},
|
||||
})
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="notes",
|
||||
data={"id": note_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
# Re-index if content changed.
|
||||
if content:
|
||||
vector = await embed(content)
|
||||
await execute_on_client(
|
||||
action="vector_upsert",
|
||||
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
|
||||
vector=vector,
|
||||
)
|
||||
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_note(note_id: str) -> str:
|
||||
"""Delete a note permanently by its UUID."""
|
||||
return json.dumps({
|
||||
"action": "delete_record",
|
||||
"table": "notes",
|
||||
"data": {"id": note_id},
|
||||
})
|
||||
await execute_on_client(action="delete", table="notes", data={"id": note_id})
|
||||
return f"Note {note_id} deleted."
|
||||
|
||||
|
||||
@registry.register
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
@@ -10,6 +9,7 @@ from langchain_core.tools import tool
|
||||
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
from app.core.llm import get_llm
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_SYSTEM_PROMPT = (
|
||||
"You are a project management assistant. You help users create, find,\n"
|
||||
@@ -36,14 +36,19 @@ async def list_projects(
|
||||
"""List projects, optionally filtered by client_id.
|
||||
include_archived: 1 to include archived projects, 0 for active only (default).
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "list",
|
||||
"table": "projects",
|
||||
"filters": {
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="projects",
|
||||
filters={
|
||||
"clientId": client_id or None,
|
||||
"includeArchived": bool(include_archived),
|
||||
},
|
||||
})
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No projects found."
|
||||
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
@@ -51,20 +56,25 @@ async def list_all_projects() -> str:
|
||||
"""List every project regardless of client or status.
|
||||
Use only when the user wants a complete cross-client overview.
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "list_all",
|
||||
"table": "projects",
|
||||
})
|
||||
result = await execute_on_client(action="select", table="projects")
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No projects found."
|
||||
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_project(project_id: str) -> str:
|
||||
"""Fetch a single project by its UUID."""
|
||||
return json.dumps({
|
||||
"action": "get",
|
||||
"table": "projects",
|
||||
"data": {"id": project_id},
|
||||
})
|
||||
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
|
||||
row = result.get("row")
|
||||
if not row:
|
||||
return f"Project {project_id} not found."
|
||||
return (
|
||||
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
|
||||
f"clientId: {row.get('clientId', 'none')})"
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
@@ -76,14 +86,13 @@ async def create_project(
|
||||
name: human-readable project name (required)
|
||||
client_id: optional UUID of the owning client
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "create_record",
|
||||
"table": "projects",
|
||||
"data": {
|
||||
"name": name,
|
||||
"clientId": client_id or None,
|
||||
},
|
||||
})
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="projects",
|
||||
data={"name": name, "clientId": client_id or None},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Project created: '{row['name']}' (id: {row['id']})"
|
||||
|
||||
|
||||
@tool
|
||||
@@ -108,11 +117,13 @@ async def update_project(
|
||||
updates["status"] = status
|
||||
if ai_summary:
|
||||
updates["aiSummary"] = ai_summary
|
||||
return json.dumps({
|
||||
"action": "update_record",
|
||||
"table": "projects",
|
||||
"data": {"id": project_id, "updates": updates},
|
||||
})
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="projects",
|
||||
data={"id": project_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
|
||||
|
||||
|
||||
@tool
|
||||
@@ -121,11 +132,8 @@ async def delete_project(project_id: str) -> str:
|
||||
IMPORTANT: prefer update_project(status='archived') unless the user
|
||||
has explicitly confirmed they want permanent deletion.
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "delete_record",
|
||||
"table": "projects",
|
||||
"data": {"id": project_id},
|
||||
})
|
||||
await execute_on_client(action="delete", table="projects", data={"id": project_id})
|
||||
return f"Project {project_id} permanently deleted."
|
||||
|
||||
|
||||
@registry.register
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
@@ -10,6 +10,7 @@ from langchain_core.tools import tool
|
||||
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
from app.core.llm import get_llm
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_SYSTEM_PROMPT = (
|
||||
"You are a task management assistant for a project workspace.\n"
|
||||
@@ -41,16 +42,24 @@ async def list_tasks(
|
||||
) -> str:
|
||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||
return json.dumps({
|
||||
"action": "list",
|
||||
"table": "tasks",
|
||||
"filters": {
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="tasks",
|
||||
filters={
|
||||
"projectId": project_id or None,
|
||||
"status": status or None,
|
||||
"search": search or None,
|
||||
"orderBy": order_by or None,
|
||||
},
|
||||
})
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks found matching the given filters."
|
||||
lines = [
|
||||
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
@@ -76,10 +85,10 @@ async def create_task(
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
is_approved: 0 until the user confirms; 1 when confirmed
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "create_record",
|
||||
"table": "tasks",
|
||||
"data": {
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="tasks",
|
||||
data={
|
||||
"title": title,
|
||||
"description": description or None,
|
||||
"status": status,
|
||||
@@ -90,7 +99,12 @@ async def create_task(
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
"isApproved": is_approved,
|
||||
},
|
||||
})
|
||||
)
|
||||
row = result["row"]
|
||||
return (
|
||||
f"Task created: '{row['title']}' "
|
||||
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
@@ -127,30 +141,41 @@ async def update_task(
|
||||
updates["projectId"] = project_id
|
||||
if is_approved != -1:
|
||||
updates["isApproved"] = is_approved
|
||||
return json.dumps({
|
||||
"action": "update_record",
|
||||
"table": "tasks",
|
||||
"data": {"id": task_id, "updates": updates},
|
||||
})
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="tasks",
|
||||
data={"id": task_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_task(task_id: str) -> str:
|
||||
"""Delete a task permanently by its UUID."""
|
||||
return json.dumps({
|
||||
"action": "delete_record",
|
||||
"table": "tasks",
|
||||
"data": {"id": task_id},
|
||||
})
|
||||
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
|
||||
return f"Task {task_id} deleted."
|
||||
|
||||
|
||||
@tool
|
||||
async def list_tasks_due_today() -> str:
|
||||
"""List all tasks whose due date falls on today's date."""
|
||||
return json.dumps({
|
||||
"action": "list_due_today",
|
||||
"table": "tasks",
|
||||
})
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
|
||||
end_ms = start_ms + 86_400_000 - 1 # last ms of today
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="tasks",
|
||||
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks are due today."
|
||||
lines = [
|
||||
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
# ── Task comment tools ────────────────────────────────────────────────
|
||||
@@ -159,11 +184,16 @@ async def list_tasks_due_today() -> str:
|
||||
@tool
|
||||
async def list_task_comments(task_id: str) -> str:
|
||||
"""List all comments on a task by its UUID."""
|
||||
return json.dumps({
|
||||
"action": "list",
|
||||
"table": "taskComments",
|
||||
"filters": {"taskId": task_id},
|
||||
})
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="taskComments",
|
||||
filters={"taskId": task_id},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return f"No comments found for task {task_id}."
|
||||
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
@@ -173,25 +203,20 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
||||
author: name or ID of the comment author
|
||||
content: comment text
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "create_record",
|
||||
"table": "taskComments",
|
||||
"data": {
|
||||
"taskId": task_id,
|
||||
"author": author,
|
||||
"content": content,
|
||||
},
|
||||
})
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="taskComments",
|
||||
data={"taskId": task_id, "author": author, "content": content},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_task_comment(comment_id: str) -> str:
|
||||
"""Delete a task comment by its UUID."""
|
||||
return json.dumps({
|
||||
"action": "delete_record",
|
||||
"table": "taskComments",
|
||||
"data": {"id": comment_id},
|
||||
})
|
||||
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
|
||||
return f"Comment {comment_id} deleted."
|
||||
|
||||
|
||||
# ── Agent ─────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Vectors routes: upsert, search, and delete cloud vector store entries."""
|
||||
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.llm import embed
|
||||
from app.schemas import (
|
||||
UserProfile,
|
||||
VectorSearchRequest,
|
||||
@@ -24,6 +25,14 @@ class _VectorDeleteRequest(BaseModel):
|
||||
ids: list[str]
|
||||
|
||||
|
||||
class _EmbedRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class _EmbedResponse(BaseModel):
|
||||
vector: list[float]
|
||||
|
||||
|
||||
@router.post("/vectors/upsert", response_model=dict)
|
||||
async def upsert_vectors(
|
||||
body: VectorUpsertRequest,
|
||||
@@ -54,3 +63,17 @@ async def delete_vectors(
|
||||
"""Delete vectors by ID, scoped to the authenticated user."""
|
||||
await _vector_store.delete(current_user.id, body.ids)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/vectors/embed", response_model=_EmbedResponse)
|
||||
async def embed_text(
|
||||
body: _EmbedRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _EmbedResponse:
|
||||
"""Generate a 1536-dim embedding vector for the given text.
|
||||
|
||||
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
||||
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
|
||||
"""
|
||||
vector = await embed(body.text)
|
||||
return _EmbedResponse(vector=vector)
|
||||
|
||||
@@ -17,6 +17,8 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from litellm import get_supported_openai_params # noqa: F401 – validates install
|
||||
|
||||
@@ -66,3 +68,13 @@ def get_router_llm(
|
||||
) -> ChatOpenAI:
|
||||
"""Return the lighter model used for intent classification / routing."""
|
||||
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
||||
|
||||
|
||||
async def embed(text: str) -> list[float]:
|
||||
"""Return a 1536-dim embedding vector for *text* using text-embedding-3-small."""
|
||||
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
response = await client.embeddings.create(
|
||||
model="text-embedding-3-small",
|
||||
input=text,
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
@@ -144,14 +144,15 @@ async def orchestrate_stream(
|
||||
request: ChatRequest,
|
||||
reg: AgentRegistry | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Streaming orchestration — yields text chunks then a final JSON frame.
|
||||
"""Streaming orchestration — yields plain text chunks only.
|
||||
|
||||
The final frame is a JSON object:
|
||||
``{"done": true, "response": "...", "actions": []}``.
|
||||
The WebSocket handler in ``app/api/routes/chat.py`` is responsible for
|
||||
wrapping each chunk in a ``text_chunk`` frame and sending the final
|
||||
``final`` frame once the generator is exhausted.
|
||||
|
||||
Agents do not yet support token-level streaming; the full response is
|
||||
fetched first, then emitted in fixed-size chunks. Token-level streaming
|
||||
will be wired in Step 6 when agents expose ``astream()``.
|
||||
fetched first (which may involve multiple WS round-trips for tool calls),
|
||||
then emitted in fixed-size chunks.
|
||||
"""
|
||||
if reg is None:
|
||||
reg = _default_registry
|
||||
@@ -163,6 +164,3 @@ async def orchestrate_stream(
|
||||
chunk_size = 50
|
||||
for i in range(0, len(response_text), chunk_size):
|
||||
yield response_text[i : i + chunk_size]
|
||||
|
||||
final = ChatResponse(response=response_text)
|
||||
yield json.dumps({"done": True, **final.model_dump()})
|
||||
|
||||
68
app/core/ws_context.py
Normal file
68
app/core/ws_context.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""WebSocket client executor context.
|
||||
|
||||
Holds a per-request async callback that tools call to execute CRUD
|
||||
operations on the Electron client's local SQLite / LanceDB databases.
|
||||
The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Coroutine
|
||||
from uuid import uuid4
|
||||
|
||||
# Holds the execute callback for the current WS session.
|
||||
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||
"_client_executor"
|
||||
)
|
||||
|
||||
|
||||
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
||||
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
||||
_client_executor.set(fn)
|
||||
|
||||
|
||||
def clear_client_executor() -> None:
|
||||
"""Remove the executor binding (best-effort; ContextVar resets on task exit)."""
|
||||
try:
|
||||
_client_executor.set(None) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def execute_on_client(
|
||||
action: str,
|
||||
table: str | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
filters: dict[str, Any] | None = None,
|
||||
vector: list[float] | None = None,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a CRUD/vector operation to the Electron client and return the result.
|
||||
|
||||
Builds a ``tool_call`` payload, invokes the per-session WS callback,
|
||||
and returns the ``tool_result`` dict from Electron.
|
||||
|
||||
Raises ``RuntimeError`` if no executor is set (i.e. called outside a WS session).
|
||||
"""
|
||||
callback = _client_executor.get(None)
|
||||
if callback is None:
|
||||
raise RuntimeError(
|
||||
"execute_on_client() called outside a WebSocket session — "
|
||||
"no client executor is set."
|
||||
)
|
||||
|
||||
payload: dict[str, Any] = {"id": str(uuid4()), "action": action}
|
||||
if table is not None:
|
||||
payload["table"] = table
|
||||
if data is not None:
|
||||
payload["data"] = data
|
||||
if filters is not None:
|
||||
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
|
||||
if vector is not None:
|
||||
payload["vector"] = vector
|
||||
if limit is not None:
|
||||
payload["limit"] = limit
|
||||
|
||||
return await callback(payload)
|
||||
@@ -5,6 +5,7 @@ Mirrors the TypeScript types from the Electron app (src/shared/api-types.ts).
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -155,3 +156,54 @@ class PluginListResponse(BaseModel):
|
||||
|
||||
class PluginInstallRequest(BaseModel):
|
||||
plugin_id: str
|
||||
|
||||
|
||||
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
||||
|
||||
class WsFrameType(str, Enum):
|
||||
chat_request = "chat_request"
|
||||
text_chunk = "text_chunk"
|
||||
tool_call = "tool_call"
|
||||
tool_result = "tool_result"
|
||||
final = "final"
|
||||
ping = "ping"
|
||||
|
||||
|
||||
class WsToolCall(BaseModel):
|
||||
"""Server → Client: requests a CRUD/vector operation on the local DB."""
|
||||
|
||||
type: Literal[WsFrameType.tool_call] = WsFrameType.tool_call
|
||||
id: str
|
||||
action: str
|
||||
table: str | None = None
|
||||
data: dict[str, Any] | None = None
|
||||
filters: dict[str, Any] | None = None
|
||||
vector: list[float] | None = None
|
||||
limit: int | None = None
|
||||
|
||||
|
||||
class WsToolResult(BaseModel):
|
||||
"""Client → Server: result of a CRUD/vector operation."""
|
||||
|
||||
type: Literal[WsFrameType.tool_result] = WsFrameType.tool_result
|
||||
id: str
|
||||
row: dict[str, Any] | None = None
|
||||
rows: list[dict[str, Any]] | None = None
|
||||
results: list[dict[str, Any]] | None = None
|
||||
deleted: bool | None = None
|
||||
ok: bool | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class WsTextChunk(BaseModel):
|
||||
"""Server → Client: incremental LLM response text."""
|
||||
|
||||
type: Literal[WsFrameType.text_chunk] = WsFrameType.text_chunk
|
||||
text: str
|
||||
|
||||
|
||||
class WsFinal(BaseModel):
|
||||
"""Server → Client: signals end of response with the complete text."""
|
||||
|
||||
type: Literal[WsFrameType.final] = WsFrameType.final
|
||||
response: str
|
||||
|
||||
Reference in New Issue
Block a user