10 Commits

Author SHA1 Message Date
47bf1881e5 deep agent 2026-03-12 18:03:27 +01:00
24a9c1b752 refactor: migrate from create_react_agent to create_deep_agent
- Replace langgraph create_react_agent with deepagents create_deep_agent
- Sub-agents now configured as SubAgent dicts dispatched via built-in task tool
- Stream filter updated: langgraph_node 'agent' → 'model'
- Accept both AIMessage and AIMessageChunk in stream filter
- Collector only captures write mutations (insert/update/delete)
- Add deepagents>=0.4.10 to requirements.txt
2026-03-12 01:21:14 +01:00
706bf88883 fix KeyError 'JSON' — escape braces in chart prompt for str.format() 2026-03-12 00:44:41 +01:00
4ff0b27084 removed WsStreamBlock class 2026-03-12 00:38:31 +01:00
61d2a18234 remove WsStreamBlock schema and tests — no longer used 2026-03-12 00:37:40 +01:00
b3687719b6 add chart tag instructions to HOME system prompt 2026-03-12 00:28:43 +01:00
f80bdfa8f7 simplify HomeFormatter to pass-through — frontend handles entity tag parsing 2026-03-12 00:10:38 +01:00
617a17db40 feat: HomeFormatter parses inline entity tags instead of tool_end blocks
The supervisor LLM now embeds <type>[id1,id2]</type> entity tags in its
response text. The HomeFormatter buffers streamed tokens, detects complete
tags across chunk boundaries, and emits WsStreamBlock with entity type +
specific IDs. This replaces the old approach of emitting blocks for every
tool_end event, which dumped ALL entities regardless of relevance.

Also fixes:
- NoneType guard on metadata in _run_graph_stream (metadata can be None)
- Updated _HOME_SYSTEM prompt with entity tag instructions
- Updated all affected tests
2026-03-12 00:01:06 +01:00
92716cb89a fix: pass tool name as positional arg to @tool decorator
The langchain @tool decorator expects the name as the first positional
argument (name_or_callable), not as name= keyword argument.
2026-03-11 23:32:14 +01:00
cfc9d7a942 refactor: replace orchestrator with LangGraph deep-agent supervisors
- Add app/core/deep_agent.py with Home and Floating supervisor graphs
  using LangGraph create_react_agent (hierarchical pattern)
- Strip ChatAgent classes from all 4 agent files, keep @tool functions
- Rewrite output_formatter.py for event-based (token/tool_end/mutations) stream
- Update device_ws.py to use run_home_stream/run_floating_stream
- Rewrite chat.py REST route to use run_home
- Add update_core_memory tool to both supervisors
- Add langgraph>=0.3.0 to requirements.txt
- Remove orchestrator.py, execution_plan.py, agent_registry.py, plans.py
- Remove PlanAction, PlanStep, ExecutionPlan, execution_mode from schemas
- Update all affected tests to match new API
- Remove 6 deprecated test files for deleted modules
- Clean up stale docstrings referencing removed orchestrator
2026-03-11 17:50:22 +01:00
35 changed files with 791 additions and 5024 deletions

View File

@@ -1,523 +0,0 @@
# AI Refactor Plan — Adiuva Backend
> **Objective:** Transform backend tools from JSON-action-descriptor-returning functions into real bidirectional executors. Each tool sends structured CRUD operations to the Electron client via WebSocket, receives real data back, and returns meaningful results to the LLM. The LLM reasons about actual user data instead of serialized action payloads.
>
> **Electron app:** Lives at `../adiuva/`. See `../adiuva/AI_REFACTOR_PLAN.md`.
>
> **Protocol:** Execute steps sequentially. Each step is atomic and committable. Mark `[x]` when done.
---
## Architecture — Before vs After
### Before (current)
```
LLM calls list_tasks(status="todo")
→ tool returns: '{"action":"list","table":"tasks","filters":{"status":"todo"}}'
→ _tool_loop feeds that JSON string as ToolMessage to LLM
→ LLM sees a descriptor, NOT real data — cannot reason about tasks
→ Final response: generic "Here are your tasks" (no actual task data)
→ Action descriptors sent in final WS frame for Electron to execute post-response
```
### After (target)
```
LLM calls list_tasks(status="todo")
→ tool calls execute_on_client(action="select", table="tasks", filters={status:"todo"})
→ WS frame sent to Electron: {type:"tool_call", id:"abc", action:"select", table:"tasks", filters:{status:"todo"}}
→ Electron runs: db.select().from(tasks).where(eq(tasks.status, "todo")).all()
→ WS frame back: {type:"tool_result", id:"abc", rows:[{id:"1",title:"Buy milk",...}, ...]}
→ tool returns: "Found 3 tasks: 1. Buy milk (high, due tomorrow) 2. ..."
→ _tool_loop feeds that as ToolMessage to LLM
→ LLM sees REAL data — can reason, count, compare, summarize
```
---
## WS Protocol — Typed Frames
| Direction | `type` | Payload |
|---|---|---|
| Client → Server | `chat_request` | `{ message: str, context: ChatContext }` |
| Server → Client | `text_chunk` | `{ text: str }` |
| Server → Client | `tool_call` | `{ id: str, action: str, table?: str, data?: dict, filters?: dict, vector?: list[float], limit?: int }` |
| Client → Server | `tool_result` | `{ id: str, row?: dict, rows?: list[dict], results?: list[dict], deleted?: bool, ok?: bool, error?: str }` |
| Server → Client | `final` | `{ response: str }` |
| Server → Client | `ping` | `{}` |
**Actions:**
| `action` | What Electron does (Drizzle) | `tool_result` shape |
|---|---|---|
| `select` | `db.select().from(table).where(filters)` | `{ rows: [...] }` |
| `get` | `db.select().from(table).where(id=...).get()` | `{ row: {...} or null }` |
| `insert` | `db.insert(table).values({id: uuid(), ...data}).returning().get()` | `{ row: {...} }` |
| `update` | `db.update(table).set(updates).where(id=...).returning().get()` | `{ row: {...} }` |
| `delete` | `db.delete(table).where(id=...).run()` | `{ deleted: true }` |
| `vector_upsert` | LanceDB upsert with pre-computed vector | `{ ok: true }` |
| `vector_search` | LanceDB search by vector | `{ results: [{id, content, score}...] }` |
**Electron generates IDs + timestamps.** Backend tools never send `id` or `createdAt` in `insert` data — Electron adds `id: uuid()`, `createdAt: Date.now()`, `updatedAt: Date.now()`.
---
## SQLite Schema Reference (Electron's local database)
Tools must use **camelCase** field names (Drizzle maps them to snake_case internally):
| Table | Columns |
|---|---|
| `tasks` | id, projectId, title, description, status (todo\|in_progress\|done), priority (high\|medium\|low), assignee (JSON array string), dueDate (ms), isAiSuggested (0\|1), isApproved (0\|1), createdAt (ms) |
| `projects` | id, clientId, name, status (active\|archived), aiSummary, createdAt (ms) |
| `timelines` | id, projectId (required), title, date (ms), isAiSuggested (0\|1), isApproved (0\|1), createdAt (ms) |
| `notes` | id, projectId, title, content (markdown), createdAt (ms), updatedAt (ms) |
| `taskComments` | id, taskId, author, content, createdAt (ms) |
| `clients` | id, parentId, name, industry, createdAt (ms) |
---
## Phase B — Backend Changes
### Step B.1 — WS context + frame types
- [x] Create `app/core/ws_context.py` (~25 lines):
- `_client_executor: ContextVar[Callable]` — holds the async callback for the current WS session
- `async def execute_on_client(action, table=None, data=None, filters=None, vector=None, limit=None) -> dict`:
- Reads callback from ContextVar
- Builds `tool_call` payload: `{id: str(uuid4()), action, table, data, filters, vector, limit}` (omits None fields)
- Calls `await callback(payload)` — which sends the WS frame and waits for `tool_result`
- Returns the result dict
- `def set_client_executor(fn)` / `def clear_client_executor()` — ContextVar management
- [x] Add to `app/schemas.py`:
- `WsFrameType(str, Enum)`: `chat_request`, `text_chunk`, `tool_call`, `tool_result`, `final`, `ping`
- `WsToolCall(BaseModel)`: `type`, `id`, `action`, `table?`, `data?`, `filters?`, `vector?`, `limit?`
- `WsToolResult(BaseModel)`: `type`, `id`, `row?`, `rows?`, `results?`, `deleted?`, `ok?`, `error?`
- `WsTextChunk(BaseModel)`: `type`, `text`
- `WsFinal(BaseModel)`: `type`, `response`
- **Files:** `app/core/ws_context.py`, `app/schemas.py`
- **Outcome:** Any tool can `await execute_on_client(...)` to query/mutate the user's local DB.
### Step B.2 — Rewrite all 23 tools to use `execute_on_client()`
- [x] Each tool: same `@tool` decorator, same parameters, same docstring. Replace `return json.dumps({...})` body with:
1. Call `result = await execute_on_client(action=..., table=..., data/filters=...)`
2. Return human-readable string with confirmation + key data from `result`
- [x] **`app/agents/task_agent.py` (8 tools):**
- `list_tasks(project_id, status, search, order_by)`:
```python
result = await execute_on_client(action="select", table="tasks", filters={
"projectId": project_id or None,
"status": status or None,
"search": search or None,
"orderBy": order_by or None,
})
rows = result.get("rows", [])
if not rows:
return "No tasks found matching the given filters."
lines = [f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
```
- `create_task(title, ...)`:
```python
result = await execute_on_client(action="insert", table="tasks", data={
"title": title, "description": description or None, "status": status,
"priority": priority, "assignee": assignees, "dueDate": due_date or None,
"projectId": project_id or None, "isAiSuggested": is_ai_suggested, "isApproved": is_approved,
})
row = result["row"]
return f"Task created: '{row['title']}' (id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
```
- `update_task(task_id, ...)`: build updates dict (same logic as now) → `execute_on_client(action="update", table="tasks", data={"id": task_id, "updates": updates})` → return "Task updated: {title}"
- `delete_task(task_id)`: `execute_on_client(action="delete", table="tasks", data={"id": task_id})` → return "Task deleted"
- `list_tasks_due_today()`: calculate today's start/end ms → `execute_on_client(action="select", table="tasks", filters={"dueDateFrom": start, "dueDateTo": end})` → format + return
- `list_task_comments(task_id)`: `execute_on_client(action="select", table="taskComments", filters={"taskId": task_id})` → format + return
- `add_task_comment(task_id, author, content)`: `execute_on_client(action="insert", table="taskComments", data={...})` → return confirmation
- `delete_task_comment(comment_id)`: `execute_on_client(action="delete", table="taskComments", data={"id": comment_id})` → return confirmation
- [x] **`app/agents/project_agent.py` (6 tools):**
- `list_projects(client_id, include_archived)`: `execute_on_client(action="select", table="projects", filters={clientId, includeArchived})` → format + return
- `list_all_projects()`: `execute_on_client(action="select", table="projects")` → format + return
- `get_project(project_id)`: `execute_on_client(action="get", table="projects", data={"id": project_id})` → return project details or "not found"
- `create_project(name, client_id)`: `execute_on_client(action="insert", table="projects", data={name, clientId})` → return confirmation + id
- `update_project(project_id, ...)`: build updates → `execute_on_client(action="update", ...)` → return confirmation
- `delete_project(project_id)`: `execute_on_client(action="delete", ...)` → return confirmation
- [x] **`app/agents/timeline_agent.py` (4 tools):**
- `list_timelines(project_id)`: `execute_on_client(action="select", table="timelines", filters={projectId})` → format + return
- `create_timeline(project_id, title, date, ...)`: `execute_on_client(action="insert", table="timelines", data={...})` → return confirmation + id
- `update_timeline(timeline_id, ...)`: build updates → `execute_on_client(action="update", ...)` → return confirmation
- `delete_timeline(timeline_id)`: `execute_on_client(action="delete", ...)` → return confirmation
- [x] **`app/agents/note_agent.py` (5 tools):**
- `list_notes(project_id)`: `execute_on_client(action="select", table="notes", filters={projectId})` → format + return
- `get_note(note_id)`: `execute_on_client(action="get", table="notes", data={"id": note_id})` → return full content or "not found"
- `create_note(title, content, project_id)`: `execute_on_client(action="insert", table="notes", data={...})` → then `execute_on_client(action="vector_upsert", data={id, projectId, content}, vector=await embed(content))` → return confirmation
- `update_note(note_id, ...)`: build updates → `execute_on_client(action="update", ...)` → then vector_upsert for updated content → return confirmation
- `delete_note(note_id)`: `execute_on_client(action="delete", ...)` → return confirmation
- **Files:** `app/agents/task_agent.py`, `app/agents/project_agent.py`, `app/agents/timeline_agent.py`, `app/agents/note_agent.py`
- **Outcome:** All 23 tools query real user data via WS. LLM sees actual rows, not action descriptors.
### Step B.3 — Bidirectional WebSocket handler
- [x] Refactor `app/api/routes/chat.py` WS endpoint:
- After auth + accept + receive `chat_request`:
1. Create `execute_on_client` callback closure capturing the websocket:
```python
pending_calls: dict[str, asyncio.Future] = {}
async def on_client_result(frame: dict):
"""Called when a tool_result frame arrives from Electron."""
fut = pending_calls.pop(frame["id"], None)
if fut and not fut.done():
fut.set_result(frame)
async def execute_callback(payload: dict) -> dict:
"""Send tool_call to Electron, wait for tool_result."""
call_id = payload["id"]
fut = asyncio.get_event_loop().create_future()
pending_calls[call_id] = fut
await websocket.send_text(json.dumps({"type": "tool_call", **payload}))
return await asyncio.wait_for(fut, timeout=30.0)
```
2. Set `client_executor` ContextVar with `execute_callback`
3. Run orchestrator in a task — it calls agents, agents call tools, tools call `execute_on_client()` which goes through the callback
4. In parallel, run a message receive loop that dispatches incoming frames:
- `tool_result` → `on_client_result(frame)`
- `ping` → ignore
5. Orchestrator yields `text_chunk` frames → send to client
6. Send `final` frame when done
7. Clear ContextVar
- Keep heartbeat ping every 30s
- 30s timeout on `tool_result` — if Electron doesn't respond, future raises `TimeoutError`, tool returns error string to LLM
- **Files:** `app/api/routes/chat.py`
- **Outcome:** Full bidirectional WS. Tool calls and text streaming happen concurrently on the same connection.
### Step B.4 — `_tool_loop` — no changes needed
- [x] Verify `app/core/agent_registry.py` works unchanged:
- `_tool_loop` calls `tool_fn.ainvoke(args)` → tool awaits `execute_on_client()` (WS round-trip) → returns string → `ToolMessage(content=string)` → LLM sees real data
- The async WS round-trip happens inside each tool. `_tool_loop` just sees an awaited tool returning a string — same as before, different content.
- **No code changes.** Just verify + add a log line for tool execution times if desired.
### Step B.5 — Orchestrator cleanup
- [x] Update `app/core/orchestrator.py`:
- `orchestrate_stream()`: remove `"actions": []` from final frame. Final becomes: `{"done": true, "response": "..."}`
- No other changes — `classify_intent` → `call_agent` → chunk response → final frame
- **Files:** `app/core/orchestrator.py`
- **Outcome:** Clean final frame. No more action descriptors in the protocol.
### Step B.6 — Add `/vectors/embed` endpoint
- [x] Add to `app/api/routes/vectors.py`:
- `POST /api/v1/storage/vectors/embed`:
- Request: `{ text: str }`
- Response: `{ vector: list[float] }` (1536-dim from `text-embedding-3-small`)
- Auth required (JWT)
- Used by:
- Backend tools: `note_agent` calls this before `vector_upsert`
- Electron: `vectordb.ts` calls this for note embedding on create/update
- **Files:** `app/api/routes/vectors.py`
- **Outcome:** Single embedding endpoint. Both backend tools and Electron can generate vectors.
---
## Verification
| What to test | How |
|---|---|
| **Read flow** | "List my tasks" → `list_tasks` → `tool_call{select, tasks}` → Electron returns rows → LLM describes real tasks |
| **Write flow** | "Create a task called Buy milk" → `create_task` → `tool_call{insert, tasks, data:{title:"Buy milk"}}` → Electron inserts + returns row → tool confirms with id |
| **Multi-tool** | "How many todo tasks do I have?" → `list_tasks(status=todo)` → LLM counts actual rows → "You have 3 todo tasks" |
| **Vector search** | "Find notes about deployment" → tool embeds → `tool_call{vector_search, vector:[...]}` → Electron searches LanceDB → returns matching notes |
| **Vector upsert** | "Create a note about..." → insert note → vector_upsert with embedding → both SQLite + LanceDB updated |
| **Tool timeout** | Disconnect Electron mid-conversation → 30s timeout → tool returns error → LLM handles gracefully |
| **Concurrent calls** | Agent calls 2 tools in sequence → each does WS round-trip → both succeed → LLM sees both results |
| **_tool_loop max iter** | Verify 5-iteration limit still works → after 5 tool calls, LLM forced to answer without tools |
---
## Execution Notes
- **Phase 1 is the critical path.** Auth + backend client + drizzle executor + orchestrator refactor must land first.
- **Steps 1.11.4 are additive** — existing app keeps working until Step 1.5 swaps the orchestrator.
- **Step 2.1 is the point of no return** — after removing LangChain, there's no local AI fallback.
- **Phase B (backend changes) must land before Phase 1.31.5** — Electron needs the bidirectional WS to talk to.
- **Phase 3 and Phase 4 are independent** — can be parallelized after Phase 2.
---
## Phase 3 — Agent System: Config, Orchestration & Cloud Connectors
> **Objective:** Backend manages all agent configuration, scheduling, orchestration, and cloud data fetching. Two agent types: **Local Directory Agent** (backend triggers Electron to read files, then AI analyzes) and **Cloud Connector Agent** (backend fetches Gmail/Teams data directly, AI analyzes, pushes results to Electron via WS tool_call). All extracted items use existing WS tool infrastructure to insert into Electron's local DB with `is_ai_suggested=True`.
>
> **Electron Phase 3 plan:** `../adiuva/AI_REFACTOR_PLAN.md` Phase 3 section.
>
> **Electron UI status (2025):** Steps 3.6, 3.7, 3.8 of the Electron plan are ✅ complete. Agents are configured inside the Settings page (`/settings?section=agents`) — not a standalone route. The `JourneyDialog` (Step 3.8) is embedded inline in the Settings → Agents section. `LocalAgentConfigPanel` and `CloudAgentConfigPanel` (Step 3.7) are also inline. This affects the journey API contract (see Step 3.5 below).
### Architecture
```
Local Agent:
Scheduler/manual trigger ──► check device online ──► WS agent_run → Electron
──► Electron reads files ──► WS agent_data → Backend
──► Backend AI (prompt_template + file content) ──► WS tool_call(insert) → Electron
──► Electron persists with isAiSuggested=1
Cloud Agent:
Scheduler/manual trigger ──► Backend fetches Gmail/Teams (OAuth) ──► Backend AI analyzes
──► check device online ──► WS tool_call(insert) → Electron ──► Electron persists
```
**New WS frame types:**
| Direction | `type` | Payload |
|---|---|---|
| Server → Client | `agent_run` | `{ run_id, agent_id, config: { paths, file_extensions, prompt_template, data_types } }` |
| Client → Server | `agent_data` | `{ run_id, files: [{ path, name, content, metadata }] }` |
| Client → Server | `agent_complete` | `{ run_id, files_read, errors }` |
| Client → Server | `device_hello` | `{ device_id, agent_ids }` |
### Step 3.1 — Agent config tables
- [x] Add to `app/models.py`:
- **`LocalAgentConfig`**:
- `id` UUID PK
- `user_id` FK → users
- `device_id` str — identifies which Electron install this config belongs to
- `name` str
- `directory_paths` JSON — list of absolute paths on the device
- `data_types` JSON — which tables to extract to: `["tasks", "notes", "timelines", "projects"]`
- `prompt_template` text — user-configured via Chatbot Journey
- `file_extensions` JSON — e.g. `[".eml", ".txt", ".pdf", ".md"]`
- `schedule_cron` str — e.g. `"0 */6 * * *"` (every 6h)
- `enabled` bool (default True)
- `last_run_at` datetime nullable
- `created_at`, `updated_at` timestamps
- **`CloudAgentConfig`**:
- `id` UUID PK
- `user_id` FK → users
- `provider` str — enum: `gmail`, `teams`, `outlook`
- `name` str
- `data_types` JSON — same format as local
- `prompt_template` text
- `oauth_token_encrypted` text — Fernet-encrypted OAuth2 credentials
- `schedule_cron` str
- `enabled` bool (default True)
- `last_run_at` datetime nullable
- `filter_config` JSON — provider-specific: `{ labels: [], date_range: {from, to}, senders: [] }`
- `created_at`, `updated_at` timestamps
- **`AgentRunLog`**:
- `id` UUID PK
- `agent_id` str — references LocalAgentConfig.id or CloudAgentConfig.id
- `agent_type` str — `local` or `cloud`
- `user_id` FK → users
- `status` str — `running`, `success`, `error`, `partial`
- `items_processed` int (default 0)
- `items_created` int (default 0)
- `errors` JSON — list of error strings
- `started_at` datetime
- `completed_at` datetime nullable
- [x] Add Pydantic schemas to `app/schemas.py`:
- `LocalAgentConfigCreate`, `LocalAgentConfigUpdate`, `LocalAgentConfigResponse`
- `CloudAgentConfigCreate`, `CloudAgentConfigUpdate`, `CloudAgentConfigResponse`
- `AgentRunLogResponse`
- `AgentCatalogItem` — `{ type, name, description, config_schema }`
- `WsAgentRun`, `WsAgentData`, `WsAgentComplete`, `WsDeviceHello`
- [x] Generate Alembic migration
- **Files:** `app/models.py`, `app/schemas.py`, `alembic/versions/`
- **Outcome:** Agent config and run tracking tables in PostgreSQL.
### Step 3.2 — Agent CRUD API routes
- [x] Create `app/api/routes/agents.py`:
- `GET /api/v1/agents/catalog` — returns hardcoded agent type catalog:
- `local_directory`: "Watches local directories, extracts data from files using AI"
- `gmail`: "Scans Gmail inbox, extracts tasks/notes from emails"
- `teams`: "Monitors Teams messages, extracts action items"
- `outlook`: "Scans Outlook inbox, extracts tasks/notes"
- `GET /api/v1/agents/local` — list user's local agent configs
- `POST /api/v1/agents/local` — create local agent config
- Body: `{ name, device_id, directory_paths, data_types, prompt_template, file_extensions, schedule_cron }`
- Tier check: count enabled agents ≤ `batch_active` limit
- `PUT /api/v1/agents/local/{id}` — update config (ownership check)
- `DELETE /api/v1/agents/local/{id}` — delete config + associated run logs
- `GET /api/v1/agents/cloud` — list user's cloud agent configs
- `POST /api/v1/agents/cloud` — create cloud connector config
- Body: `{ provider, name, data_types, prompt_template, oauth_token_encrypted, schedule_cron, filter_config }`
- Tier check: same `batch_active` limit (local + cloud count together)
- `PUT /api/v1/agents/cloud/{id}` — update config
- `DELETE /api/v1/agents/cloud/{id}` — delete config + run logs
- `GET /api/v1/agents/runs` — query params: `agent_id`, `page`, `limit` → paginated run logs
- `POST /api/v1/agents/{id}/run` — manual trigger (dispatches to agent runner)
- All routes require JWT auth; ownership enforced on all mutations
- [x] Register router in `app/main.py`
- **Files:** `app/api/routes/agents.py`, `app/main.py`
- **Outcome:** Full CRUD for agent configs with tier-gated creation limits.
### Step 3.3 — Device WS endpoint
- [x] Create `app/api/routes/device_ws.py`:
- `WebSocket /api/v1/ws/device?token=<jwt>` — persistent connection from Electron
- On connect:
- Authenticate JWT
- Receive `device_hello` frame → extract `device_id`, `agent_ids`
- Store connection in `DeviceConnectionManager` (in-memory dict: `user_id → { ws, device_id }`)
- Check for overdue agent runs → trigger them immediately
- Message loop:
- `agent_data` → route to active agent run handler
- `agent_complete` → finalize agent run
- `tool_result` → route to pending tool call (same pattern as chat WS)
- `pong` → heartbeat ack
- On disconnect:
- Remove from `DeviceConnectionManager`
- Mark any in-progress agent runs as `error` with "device disconnected"
- Heartbeat: send `ping` every 30s, disconnect if no `pong` within 10s
- [x] Create `app/core/device_manager.py`:
- `DeviceConnectionManager` (singleton):
- `register(user_id, device_id, ws)` — stores active connection
- `unregister(user_id)` — removes connection
- `get_ws(user_id) -> WebSocket | None` — returns active WS if device is online
- `is_online(user_id, device_id=None) -> bool` — optionally checks specific device
- `send_frame(user_id, frame: dict)` — sends JSON frame to device
- **Files:** `app/api/routes/device_ws.py`, `app/core/device_manager.py`, `app/main.py`
- **Outcome:** Backend maintains persistent WS connections to Electron devices for agent triggers.
### Step 3.4 — Agent run orchestrator
- [x] Create `app/core/agent_runner.py`:
- `async run_local_agent(user_id, config: LocalAgentConfig, device_mgr: DeviceConnectionManager)`:
1. Check device is online with matching `device_id` → abort if offline
2. Create `AgentRunLog` with `status=running`
3. Send `WsAgentRun` frame to Electron with config (paths, extensions, prompt)
4. Await `WsAgentData` frames — collect file contents
5. Await `WsAgentComplete` frame — Electron signals done reading
6. For each file: call LLM with `prompt_template` + file content → extract structured items
7. For each extracted item: send `WsToolCall(insert, table, data)` to Electron → await `WsToolResult`
- All inserts include `is_ai_suggested=True, is_approved=False`
8. Update `AgentRunLog`: `status=success`, `items_processed`, `items_created`
- `async run_cloud_agent(user_id, config: CloudAgentConfig, device_mgr: DeviceConnectionManager)`:
1. Check device is online → abort if offline (results must push to Electron)
2. Create `AgentRunLog` with `status=running`
3. Decrypt OAuth credentials from `config.oauth_token_encrypted`
4. Fetch data from cloud provider (Step 3.6):
- Gmail: `google-api-python-client` + `filter_config` label/date filters
- Teams: `msgraph-sdk` + channel/date filters
- Outlook: `msgraph-sdk` + folder/date filters
5. For each item: call LLM with `prompt_template` + email/message content → extract structured items
6. For each extracted item: send `WsToolCall(insert)` to Electron → await `WsToolResult`
7. Update `AgentRunLog`
- `async trigger_pending_runs(user_id, device_id, device_mgr)`:
- Called when Electron connects (after `device_hello`)
- Queries all enabled agent configs where `last_run_at + schedule_interval < now()`
- For local agents: only triggers if `config.device_id == device_id`
- For cloud agents: triggers regardless of device (any connected device can receive results)
- Executes runs sequentially (one at a time to avoid overwhelming the WS)
- Error handling: on any failure, update `AgentRunLog` with `status=error` + error details
- [x] Wire `POST /agents/{id}/run` endpoint to dispatch background task via `asyncio.create_task()`
- [x] Replace `_trigger_pending_runs_stub` in `device_ws.py` with real `trigger_pending_runs` call
- [x] Add `croniter>=3.0.0` to `requirements.txt`
- [x] 23 unit + integration tests covering all code paths
- **Files:** `app/core/agent_runner.py`, `app/api/routes/agents.py`, `app/api/routes/device_ws.py`, `requirements.txt`, `tests/test_agent_runner.py`
- **Outcome:** Backend drives all agent execution — both local (via WS file request) and cloud (direct API calls — stub until Step 3.6).
### Step 3.5 — Chatbot Journey endpoint
- [x] Create `app/api/routes/agent_setup.py`:
- `POST /api/v1/agents/journey/start`:
- Body: `{ agent_type: "local"|"cloud", agent_id: str | None }`
- `agent_type`: which kind of agent this journey configures.
- `agent_id`: optional — if provided, the session is pre-seeded with the existing agent's `prompt_template` so the user can refine it. If absent, fresh journey.
- **No `data_types` field** — data types are determined through the conversation itself, not sent upfront.
- Creates a journey session (in-memory or Redis-backed)
- Returns first AI message: contextual question based on agent type
- Local: "What kind of files are in the directories you want to monitor? (emails, documents, logs, etc.)"
- Cloud: "What kind of emails/messages should I look for? (client communications, invoices, meeting notes, etc.)"
- Response: `{ session_id, message, done: false }`
- **Electron note:** `proxyPost` auto-converts camelCase keys to snake_case. Electron sends `{ agentType, agentId }` → backend receives `{ agent_type, agent_id }`.
- `POST /api/v1/agents/journey/message`:
- Body: `{ session_id, message }`
- AI processes user's answer, asks follow-up questions (max 5 turns)
- System prompt: "You are configuring a data extraction agent for a freelancer. Ask about file format, what data to extract (tasks, notes, timelines), naming conventions, priority rules, and any special mapping. After 3-5 questions, generate a detailed prompt_template."
- When AI determines enough context: `{ session_id, message: "Here's your configuration...", done: true, prompt_template: "..." }`
- The `prompt_template` is a structured instruction for the extraction LLM (e.g. "Extract tasks from email. Subject becomes task title. If body contains 'urgent' or 'ASAP', set priority to 'high'. Extract due dates if mentioned.")
- **Electron note:** `toCamelCase` converts the response → Electron reads `promptTemplate` from the final message and auto-fills the agent config panel. User clicks "Save & apply" which calls `agent.local.update` / `agent.cloud.update` tRPC mutation.
- **Files:** `app/api/routes/agent_setup.py`, `app/main.py`
- **Outcome:** Users configure AI prompts through guided conversation. Journey can refine an existing config when `agent_id` is provided. ✅
### Step 3.6 — Cloud provider integrations
- [x] Create `app/integrations/gmail.py`:
- `GmailClient`:
- `__init__(oauth_token)` — initializes Google API client
- `async fetch_messages(filter_config, since: datetime) -> list[EmailMessage]`
- `EmailMessage`: `{ id, subject, sender, body_text, date, labels }`
- Handles token refresh via Google OAuth2 refresh flow
- Respects `filter_config.labels`, `filter_config.date_range`, `filter_config.senders`
- [x] Create `app/integrations/ms_graph.py`:
- `MSGraphClient`:
- `__init__(oauth_token)` — initializes MS Graph client
- `async fetch_emails(filter_config, since: datetime) -> list[EmailMessage]` (Outlook)
- `async fetch_messages(filter_config, since: datetime) -> list[ChatMessage]` (Teams)
- `ChatMessage`: `{ id, content, sender, channel, date }`
- Handles token refresh via MSAL
- [x] Create `app/integrations/__init__.py` — factory: `get_provider(provider_name) -> GmailClient | MSGraphClient`
- **Dependencies:** `google-api-python-client`, `google-auth-oauthlib`, `msgraph-sdk`, `msal`
- **Files:** `app/integrations/gmail.py`, `app/integrations/ms_graph.py`, `app/integrations/__init__.py`
- **Outcome:** Backend can fetch emails/messages from Gmail, Outlook, and Teams.
### Step 3.7 — Agent scheduler
- [ ] Create `app/core/agent_scheduler.py`:
- Uses `APScheduler` (or simple asyncio loop) to check agent schedules
- Every 60s: query enabled agents where `last_run_at + cron_interval < now()`
- For each due agent:
- Check if user's device is online via `DeviceConnectionManager`
- If online: dispatch to `agent_runner`
- If offline: skip (will trigger on next `device_hello`)
- Locks: use PostgreSQL advisory locks to prevent duplicate runs in multi-instance deployments
- [ ] Integrate with FastAPI lifespan (start scheduler on app startup, shutdown gracefully)
- **Dependencies:** `apscheduler>=4.0`
- **Files:** `app/core/agent_scheduler.py`, `app/main.py`
- **Outcome:** Agents run automatically on their configured schedules.
### Step 3.8 — OAuth flow endpoints
- [ ] Create `app/api/routes/oauth.py`:
- `GET /api/v1/oauth/{provider}/authorize` — returns OAuth authorization URL
- Gmail: Google OAuth2 with `gmail.readonly` scope
- Outlook/Teams: MS identity platform with `Mail.Read`, `ChannelMessage.Read.All` scopes
- `GET /api/v1/oauth/{provider}/callback` — handles OAuth redirect
- Exchanges auth code for access + refresh tokens
- Encrypts tokens with Fernet (server-side key from settings)
- Returns encrypted token blob for storage in `CloudAgentConfig.oauth_token_encrypted`
- `POST /api/v1/oauth/{provider}/refresh` — refresh expired OAuth token
- **Files:** `app/api/routes/oauth.py`, `app/main.py`
- **Outcome:** Users can connect Gmail/Teams/Outlook accounts securely.
---
### Phase 3 — Verification
| # | Scenario | Expected |
|---|---|---|
| 1 | **Agent CRUD** | Create/read/update/delete local and cloud configs; tier limits enforced (free=2, pro=10) |
| 2 | **WS device connect** | Electron connects → `device_hello` → backend stores connection → triggers overdue runs |
| 3 | **Local agent run** | Backend sends `agent_run` → Electron reads files → `agent_data` → backend AI extracts → `tool_call(insert)` → Electron persists with `isAiSuggested=1` |
| 4 | **Cloud agent run** | Backend fetches Gmail → AI extracts tasks → `tool_call(insert)` → Electron persists |
| 5 | **Device binding** | Local agent config with `device_id=A` only triggers when device A is connected |
| 6 | **Chatbot Journey** | Start journey → 3-5 Q&A turns → produces valid `prompt_template` |
| 7 | **Schedule** | Agent with `schedule_cron="0 */6 * * *"` runs every 6h when device is online |
| 8 | **Offline resilience** | Device offline → runs skipped → device reconnects → overdue runs trigger immediately |
| 9 | **OAuth flow** | Gmail authorize → callback → token encrypted → stored in config → fetch emails works |
### Phase 3 — New Dependencies
| Package | Purpose |
|---|---|
| `google-api-python-client` | Gmail API access |
| `google-auth-oauthlib` | Gmail OAuth2 flow |
| `msgraph-sdk` | Outlook + Teams API access |
| `msal` | MS identity platform auth |
| `apscheduler>=4.0` | Agent scheduling |
| `cryptography` (Fernet) | OAuth token encryption at rest |
---
## ~~Phase 5 — Shared Memory~~ (SUPERSEDED)
> **This phase has been fully replaced by `V3_MIGRATION_PLAN.md`.**
>
> - Chat WS fix → V3 Step 5 (Unified WS Handler — single multiplexed socket)
> - Agent memory → V3 Steps 67 (Cloud-side MemGPT-style memory in PostgreSQL + pgvector, encrypted at rest with per-user Fernet key)
>
> The on-device KV approach (Electron SQLite `agent_memory` table) is no longer the target architecture.
> See `V3_MIGRATION_PLAN.md` for the current plan.

View File

@@ -1,572 +0,0 @@
# Backend Plan — Adiuva Cloud API
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
>
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, E2E backup blob storage, cloud storage (encrypted blobs), cloud vector store, and plugin marketplace.
> The backend NEVER persists user data in plaintext. Cloud storage blobs are E2E encrypted before upload — the backend only verifies integrity, never decrypts.
---
## Project Structure
```
adiuva-api/
├── app/
│ ├── __init__.py
│ ├── main.py # FastAPI entry + CORS + lifespan + router includes
│ ├── core/
│ │ ├── __init__.py
│ │ ├── agent_registry.py # Base classes + singleton registry
│ │ ├── orchestrator.py # LLM-based intent router
│ │ ├── execution_plan.py # Plan builder + cache
│ │ └── plugin_loader.py # Dynamic agent loading
│ ├── agents/ # Chat agents (proprietary logic + prompts)
│ │ ├── __init__.py # Auto-registers all agents
│ │ ├── task_agent.py
│ │ ├── calendar_agent.py
│ │ ├── email_agent.py
│ │ └── analytics_agent.py
│ ├── api/
│ │ ├── __init__.py
│ │ ├── routes/
│ │ │ ├── __init__.py
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
│ │ │ ├── plans.py # GET /plans/playbook
│ │ │ ├── storage.py # CRUD cloud storage (E2E encrypted blobs)
│ │ │ ├── vectors.py # Upsert/search cloud vector store
│ │ │ ├── backup.py # PUT/GET /backup
│ │ │ ├── plugins.py # Plugin marketplace
│ │ │ ├── auth.py # Register/login/refresh
│ │ │ └── billing.py # Checkout/webhook/subscription
│ │ └── middleware/
│ │ ├── __init__.py
│ │ ├── auth.py # JWT validation
│ │ ├── rate_limit.py # Tier-aware rate limiting
│ │ └── sanitizer.py # Strip prompt metadata from responses
│ ├── storage/
│ │ ├── __init__.py
│ │ ├── blob_store.py # S3 for E2E encrypted blobs
│ │ ├── vector_store.py # Cloud vector store (Pinecone/Qdrant)
│ │ └── encryption.py # Integrity verification only — NO decryption
│ ├── marketplace/
│ │ ├── __init__.py
│ │ ├── plugin_registry.py # Plugin catalog (metadata, versions, ratings)
│ │ ├── plugin_review.py # Review queue + approval workflow
│ │ └── revenue_share.py # 70/30 split tracking with Stripe Connect
│ ├── billing/
│ │ ├── __init__.py
│ │ ├── stripe_service.py # Stripe checkout + webhooks
│ │ └── tier_manager.py # Feature matrix per tier
│ └── config/
│ ├── __init__.py
│ └── settings.py # Pydantic BaseSettings (env-based)
├── tests/
│ ├── __init__.py
│ ├── conftest.py # Fixtures: test client, mock agents, mock LLM
│ ├── test_orchestrator.py
│ ├── test_agents.py
│ ├── test_auth.py
│ ├── test_backup.py
│ ├── test_storage.py
│ └── test_plugins.py
├── alembic/ # DB migrations (auth/billing/marketplace tables only)
│ ├── alembic.ini
│ └── versions/
├── requirements.txt
├── Dockerfile
├── docker-compose.yml # App + PostgreSQL + Redis (dev)
├── .env.example
└── README.md
```
---
## Step-by-Step Implementation
### Step 1 — Project scaffolding ✅
- [x] Initialize repo with the directory structure above
- [x] Write `requirements.txt`:
```
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
langchain>=0.3.0
langchain-openai>=0.3.0
pydantic>=2.10.0
python-jose[cryptography]>=3.3.0
stripe>=11.0.0
boto3>=1.35.0
slowapi>=0.1.9
sqlalchemy>=2.0.0
asyncpg>=0.30.0
alembic>=1.14.0
bcrypt>=4.2.0
python-dotenv>=1.0.0
httpx>=0.28.0
websockets>=14.0
pytest>=8.0.0
pytest-asyncio>=0.24.0
```
- [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
- [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod), `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
- [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
- [x] Write `docker-compose.yml`: app, postgres:16, optional redis
- [x] Write `.env.example`
- **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes).
### Step 2 — Pydantic schemas (API contracts) ✅
- [x] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification', 'call_agent']`, `table: str | None`, `data: dict | None`, `agent: str | None`
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
- `StorageRecord`: `id: str`, `user_id: str`, `table: str`, `blob: bytes`, `checksum: str`, `created_at: int`, `updated_at: int` — blob is always E2E encrypted by client
- `StorageRecordCreate`: `table: str`, `blob: bytes`, `checksum: str`
- `StorageRecordUpdate`: `blob: bytes`, `checksum: str`
- `VectorUpsertRequest`: `vectors: list[VectorItem]`
- `VectorItem`: `id: str`, `blob: bytes`, `checksum: str` — vector + metadata encrypted by client
- `VectorSearchRequest`: `query_blob: bytes`, `top_k: int = 10`
- `VectorSearchResponse`: `results: list[VectorSearchResult]`
- `VectorSearchResult`: `id: str`, `score: float`, `blob: bytes`
- `PluginManifest`: `id: str`, `name: str`, `description: str`, `version: str`, `author: str`, `permissions: list[str]`, `category: str`, `price_cents: int = 0`
- `PluginListResponse`: `plugins: list[PluginManifest]`, `total: int`, `page: int`
- `PluginInstallRequest`: `plugin_id: str`
- **Outcome:** All request/response models defined and validated.
### Step 3 — Agent Registry + base classes ✅
- [x] `app/core/agent_registry.py`:
- `BaseAgent(ABC)`:
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
- Abstract `get_name() -> str`, `get_description() -> str`
- `ChatAgent(BaseAgent)`:
- Abstract `async handle(query: str, context: dict) -> str`
- Abstract `get_tools() -> list` (LangChain tool definitions)
- Concrete `_tool_loop(llm, messages, tools, max_iter=5) -> str` — shared tool-calling loop
- `AgentRegistry` (singleton):
- `_agents: dict[str, ChatAgent]`
- `register(agent_class)` — decorator pattern
- `get(name) -> ChatAgent`
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
- `async call_agent(name, query, context) -> str` — for inter-agent calls
- [x] Unit tests: register, get, list, call_agent with mock
- **Outcome:** Pluggable agent framework.
### Step 4 — Orchestrator ✅
- [x] `app/core/orchestrator.py`:
- `async classify_intent(message, context, registry) -> str`:
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
- Uses gpt-4o-mini via LangChain for low latency
- Falls back to `task_agent` if no clear match
- `async route_single(agent_name, message, context) -> ChatResponse`:
- Instantiates agent from registry
- Calls `agent.handle(message, context)`
- Returns response + any actions the agent produced
- `async route_pipeline(agent_names, message, context) -> ChatResponse`:
- Executes agents in sequence
- Each agent receives `{...context, previous_results: [...]}`
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
- Main entry point
- Context is transparent to orchestrator — data may originate from local or cloud storage on the client side
- Classifies intent
- If `execution_mode == 'direct'`: route + return response
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
- Same as orchestrate but yields tokens for WebSocket streaming
- [x] Integration tests with mocked LLM and mocked agents
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
### Step 5 — Execution Plan generator ✅
- [x] `app/core/execution_plan.py`:
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
- `ExecutionPlanBuilder`:
- `add_step(action, params) -> self`
- `add_llm_step(template_id, variables) -> self`
- `add_data_step(action, data_from_step) -> self`
- `build() -> ExecutionPlan` — validates step references
- `PlanCache`:
- In-memory LRU (maxsize=1000)
- `cache_plan(key, plan)`, `get_plan(key)`, `get_all_playbooks() -> list[ExecutionPlan]`
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
### Step 6 — Chat Agents ✅
- [x] `app/agents/task_agent.py` — `@registry.register`:
- Description: "Manages tasks and comments: list, create, update, delete, due-today, comments"
- Tools (8): `list_tasks(project_id, status, search, order_by)`, `create_task(title, description, status, priority, assignees, due_date, project_id, is_ai_suggested, is_approved)`, `update_task(task_id, ...)`, `delete_task(task_id)`, `list_tasks_due_today()`, `list_task_comments(task_id)`, `add_task_comment(task_id, author, content)`, `delete_task_comment(comment_id)`
- status: `todo|in_progress|done`; priority: `high|medium|low`; assignees: JSON-encoded string; due_date: ms timestamp
- Accepts flexible context; sentinel `-1` for optional integer update fields
- [x] `app/agents/timeline_agent.py` — `@registry.register`:
- Description: "Manages project timelines (milestones): list, create, update, delete"
- Tools (4): `list_timelines(project_id)`, `create_timeline(project_id, title, date, is_ai_suggested, is_approved)`, `update_timeline(timeline_id, ...)`, `delete_timeline(timeline_id)`
- `project_id` is required for create; date is a ms timestamp; supports AI-suggestion + approval workflow
- [x] `app/agents/project_agent.py` — `@registry.register`:
- Description: "Manages projects: list, get, create, update, archive, delete"
- Tools (6): `list_projects(client_id, include_archived)`, `list_all_projects()`, `get_project(project_id)`, `create_project(name, client_id)`, `update_project(project_id, ...)`, `delete_project(project_id)`
- status: `active|archived`; prefers archive over deletion (docstring guard on delete)
- [x] `app/agents/note_agent.py` — `@registry.register`:
- Description: "Manages notes: list, get, create, update, delete"
- Tools (5): `list_notes(project_id)`, `get_note(note_id)`, `create_note(title, content, project_id)`, `update_note(note_id, ...)`, `delete_note(note_id)`
- content is Markdown; `get_note` should be called before update to preserve existing content
- [x] `app/agents/__init__.py`: imports all four agent modules to trigger `@registry.register` decorators
- [x] Unit tests per agent with mocked LLM (registration, names, tool counts, handle(), direct tool invocation)
- **Outcome:** Four domain-specific agents matching the UI data model (Tasks, Timelines, Projects, Notes), all registered and tested.
### Step 7 — Storage Layer ✅
- [x] `app/storage/blob_store.py`:
- `BlobStore`: `async upload`, `async download`, `async delete` (idempotent), `async list_keys`
- Keys: `{user_id}/{table}/{record_id}` — backend never inspects blob content
- boto3 S3 with SSE-S3 at-rest encryption; client checksum stored in S3 object metadata
- [x] `app/storage/vector_store.py`:
- `VectorStore`: `async upsert`, `async search`, `async delete`
- Pinecone (default, `namespace=user_id`) or Qdrant (`user_id` payload filter) — runtime-configurable
- 32-dim SHA-256-derived float vector; blob stored as base64 in metadata/payload
- ANN on encrypted data: known accuracy trade-off, documented
- [x] `app/storage/encryption.py`:
- `verify_checksum(blob, checksum) -> bool` — SHA-256 + `hmac.compare_digest` (constant-time)
- `reject_if_tampered(blob, checksum)` — raises `HTTP 400` on mismatch
- Backend NEVER holds decryption keys
- [x] `app/schemas.py`: added `StorageRecord*`, `VectorItem`, `VectorUpsertRequest`, `VectorSearch*`, `Plugin*` schemas
- [x] `app/config/settings.py`: added `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
- [x] `requirements.txt`: added `moto[s3]`, `pinecone`, `qdrant-client`
- [x] 37 unit tests covering encryption, BlobStore (moto), VectorStore Pinecone, VectorStore Qdrant
- **Outcome:** Cloud storage layer that handles E2E encrypted blobs without ever accessing plaintext.
### Step 8 — API Routes ✅
#### 8a — Chat endpoint
- [x] `app/api/routes/chat.py`:
- `POST /api/v1/chat`:
- Request: `ChatRequest`
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
- Response: `ChatResponse` or `ExecutionPlan`
- `WebSocket /api/v1/chat/stream`:
- Client sends `ChatRequest` as first JSON frame
- Server yields token strings via `orchestrate_stream()`
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
- Heartbeat ping every 30s to keep connection alive
#### 8b — Plans endpoint
- [x] `app/api/routes/plans.py`:
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
#### 8c — Storage endpoint (cloud records)
- [x] `app/api/routes/storage.py`:
- `POST /api/v1/storage/records`: Create encrypted record
- Request: `StorageRecordCreate`
- Verifies checksum, stores blob in S3, inserts metadata row in PostgreSQL
- Response: `{id: str, created_at: int}`
- `GET /api/v1/storage/records`: List record metadata (no blobs)
- Query params: `table: str`, `page: int`, `limit: int`
- Response: `list[{id, table, checksum, created_at, updated_at}]`
- `GET /api/v1/storage/records/{id}`: Download encrypted blob
- Response: blob bytes + `X-Checksum` header
- `PUT /api/v1/storage/records/{id}`: Update encrypted blob
- Request: `StorageRecordUpdate`
- `DELETE /api/v1/storage/records/{id}`: Delete record + S3 blob
- All routes enforce tier cloud_storage_gb quota via `TierManager.check_quota(user_id)`
#### 8d — Vectors endpoint (cloud vector store)
- [x] `app/api/routes/vectors.py`:
- `POST /api/v1/storage/vectors/upsert`:
- Request: `VectorUpsertRequest`
- Verifies checksums, delegates to `VectorStore.upsert()`
- Response: `{upserted: int}`
- `POST /api/v1/storage/vectors/search`:
- Request: `VectorSearchRequest`
- Delegates to `VectorStore.search()`
- Response: `VectorSearchResponse`
- `DELETE /api/v1/storage/vectors`:
- Request: `{ids: list[str]}`
#### 8e — Backup endpoint
- [x] `app/api/routes/backup.py`:
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
- Free: 0 (no backup)
- Pro: 5 GB
- Power: 25 GB
- Team: unlimited
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
#### 8f — Plugins endpoint
- [x] `app/api/routes/plugins.py`:
- `GET /api/v1/plugins`:
- Query params: `category: str | None`, `q: str | None`, `page: int`, `sort: Literal['rating', 'installs', 'newest']`
- Response: `PluginListResponse`
- Available from Power tier and above
- `GET /api/v1/plugins/{id}`:
- Response: `PluginManifest` + ratings + install count
- `POST /api/v1/plugins/{id}/install`:
- Request: `PluginInstallRequest`
- Records installation for the user (billing tracking, analytics)
- If plugin is paid: triggers Stripe Connect charge + revenue split (70% developer, 30% platform)
- Response: `{ok: true, download_url: str}` — signed S3 URL for plugin package
- `DELETE /api/v1/plugins/{id}/install`:
- Unregisters installation
#### 8g — Auth endpoint
- [x] `app/api/routes/auth.py`:
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
#### 8h — Billing endpoint
- [x] `app/api/routes/billing.py`:
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
- `GET /api/v1/billing/subscription`: Returns current subscription info
- `DELETE /api/v1/billing/subscription`: Cancels subscription
- **Outcome:** Complete REST + WebSocket API covering orchestration, storage, vectors, backup, marketplace.
### Step 9 — Middleware
#### 9a — Auth middleware
- [x] `app/api/middleware/auth.py`:
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
- Validates JWT signature, expiry, extracts `user_id` and `tier`
- Raises `401` on invalid/expired token
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
#### 9b — Rate limiter
- [x] `app/api/middleware/rate_limit.py`:
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
- Tier-based limits:
- Free: 20 req/min
- Pro: 60 req/min
- Power: 120 req/min
- Team: 200 req/seat/min
- Custom 429 response with `Retry-After` header
#### 9c — Sanitizer
- [x] `app/api/middleware/sanitizer.py`:
- Response middleware that scans response bodies
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
- Pattern-based detection + exact match against known prompt fingerprints
- Logs sanitization events for monitoring
- **Outcome:** Secure, rate-limited API with prompt IP protection.
### Step 10 — Plugin Marketplace ✅
- [x] `app/marketplace/plugin_registry.py`:
- `PluginRegistry`:
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
- `async get_plugin(plugin_id) -> PluginManifest | None`
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
- `async reject_plugin(plugin_id, reason: str) -> None`
- [x] `app/marketplace/plugin_review.py`:
- `ReviewQueue`:
- `async get_pending() -> list[dict]`
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
- [x] `app/marketplace/revenue_share.py`:
- `RevenueShare`:
- `async record_install(plugin_id, user_id, amount_cents) -> None`
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
- `async get_earnings(developer_id, period) -> dict`
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
### Step 11 — Billing & Tier management ✅
- [x] `app/billing/stripe_service.py`:
- `create_checkout_session(user_id, tier) -> str`
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
- `get_subscription(user_id) -> dict | None`
- `cancel_subscription(user_id) -> None`
- [x] `app/billing/tier_manager.py`:
- `TierManager`:
- Feature matrix:
```python
FEATURES = {
'free': {
'agents': 3,
'batch_active': 2,
'cloud_storage_gb': 0,
'backup_gb': 0,
'providers': 1,
'batch_builder': False,
'plugin_marketplace': False,
'sso': False,
},
'pro': {
'agents': -1, # unlimited
'batch_active': 10,
'cloud_storage_gb': 5,
'backup_gb': 5,
'providers': -1,
'batch_builder': False,
'plugin_marketplace': False,
'sso': False,
},
'power': {
'agents': -1,
'batch_active': -1, # unlimited
'cloud_storage_gb': 25,
'backup_gb': 25,
'providers': -1,
'batch_builder': True,
'plugin_marketplace': True,
'sso': False,
},
'team': {
'agents': -1,
'batch_active': -1,
'cloud_storage_gb': -1,
'backup_gb': -1,
'providers': -1,
'batch_builder': True,
'plugin_marketplace': True,
'sso': True,
},
}
```
- `get_tier(user_id) -> BillingTier`
- `check_feature(user_id, feature) -> bool`
- `get_rate_limit(tier) -> int`
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons
- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService`
- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota`
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
### Step 12 — Database (auth/billing/marketplace only)
- [x] PostgreSQL schema via Alembic:
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
- `storage_records`: `id UUID PK`, `user_id FK`, `table_name VARCHAR`, `s3_key`, `checksum`, `size_bytes`, `created_at`, `updated_at` — metadata only, no plaintext
- `plugins`: `id UUID PK`, `name`, `description`, `version`, `author_id FK`, `category`, `status` (pending_review/approved/rejected), `price_cents`, `s3_package_key`, `install_count`, `avg_rating`, `created_at`
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
- [x] Initial Alembic migration
- [x] SQLAlchemy models in `app/models.py`
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
### Step 13 — Testing & deployment ✅
- [x] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone
- [x] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
- [x] `tests/test_agents.py`: each agent with mocked tools
- [x] `tests/test_auth.py`: register → login → access protected → refresh → expired token
- [x] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
- [x] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement
- [x] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked)
- [x] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
- [x] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
- **Outcome:** Fully tested, deployable backend.
---
## API Contract Summary
| Method | Endpoint | Auth | Request | Response |
|--------|----------|------|---------|----------|
| POST | `/api/v1/auth/register` | No | `{email, password}` | `AuthTokens` |
| POST | `/api/v1/auth/login` | No | `{email, password}` | `AuthTokens` |
| POST | `/api/v1/auth/refresh` | No | `{refresh_token}` | `AuthTokens` |
| GET | `/api/v1/auth/me` | JWT | — | `UserProfile` |
| POST | `/api/v1/chat` | JWT | `ChatRequest` | `ChatResponse \| ExecutionPlan` |
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
| POST | `/api/v1/storage/records` | JWT | `StorageRecordCreate` | `{id, created_at}` |
| GET | `/api/v1/storage/records` | JWT | `?table&page&limit` | `RecordMeta[]` |
| GET | `/api/v1/storage/records/:id` | JWT | — | Binary blob |
| PUT | `/api/v1/storage/records/:id` | JWT | `StorageRecordUpdate` | `{ok: true}` |
| DELETE | `/api/v1/storage/records/:id` | JWT | — | `{ok: true}` |
| POST | `/api/v1/storage/vectors/upsert` | JWT | `VectorUpsertRequest` | `{upserted: int}` |
| POST | `/api/v1/storage/vectors/search` | JWT | `VectorSearchRequest` | `VectorSearchResponse` |
| DELETE | `/api/v1/storage/vectors` | JWT | `{ids: list[str]}` | `{ok: true}` |
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
| GET | `/api/v1/backup` | JWT | — | Binary blob |
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
| GET | `/api/v1/plugins` | JWT | `?category&q&page&sort` | `PluginListResponse` |
| GET | `/api/v1/plugins/:id` | JWT | — | `PluginManifest` + stats |
| POST | `/api/v1/plugins/:id/install` | JWT | `PluginInstallRequest` | `{ok, download_url}` |
| DELETE | `/api/v1/plugins/:id/install` | JWT | — | `{ok: true}` |
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
| DELETE | `/api/v1/billing/subscription` | JWT | — | `{ok: true}` |
| GET | `/api/v1/health` | No | — | `{status, version}` |
| GET | `/api/v1/agents/catalog` | JWT | — | `AgentCatalogItem[]` |
| GET | `/api/v1/agents/local` | JWT | — | `LocalAgentConfigResponse[]` |
| POST | `/api/v1/agents/local` | JWT | `LocalAgentConfigCreate` | `LocalAgentConfigResponse` |
| PUT | `/api/v1/agents/local/{id}` | JWT | `LocalAgentConfigUpdate` | `LocalAgentConfigResponse` |
| DELETE | `/api/v1/agents/local/{id}` | JWT | — | `{ok: true}` |
| GET | `/api/v1/agents/cloud` | JWT | — | `CloudAgentConfigResponse[]` |
| POST | `/api/v1/agents/cloud` | JWT | `CloudAgentConfigCreate` | `CloudAgentConfigResponse` |
| PUT | `/api/v1/agents/cloud/{id}` | JWT | `CloudAgentConfigUpdate` | `CloudAgentConfigResponse` |
| DELETE | `/api/v1/agents/cloud/{id}` | JWT | — | `{ok: true}` |
| GET | `/api/v1/agents/runs` | JWT | `?agent_id&page&limit` | `AgentRunLogResponse[]` |
| POST | `/api/v1/agents/{id}/run` | JWT | — | `{ok: true, run_id}` |
| POST | `/api/v1/agents/journey/start` | JWT | `{agent_type, data_types}` | `{session_id, message, done}` |
| POST | `/api/v1/agents/journey/message` | JWT | `{session_id, message}` | `{session_id, message, done, prompt_template?}` |
| GET | `/api/v1/oauth/{provider}/authorize` | JWT | — | `{authorization_url}` |
| GET | `/api/v1/oauth/{provider}/callback` | — | OAuth code | `{encrypted_token}` |
| WS | `/api/v1/ws/device` | JWT | `device_hello` (first frame) | Agent trigger + tool_call frames |
---
## Stack
| Layer | Technology |
|-------|-----------|
| Framework | FastAPI + Uvicorn |
| LLM | LangChain + langchain-openai |
| Auth | PyJWT + bcrypt + OAuth2 |
| Billing | stripe-python + Stripe Connect |
| Blob storage | boto3 (S3) |
| Vector store | Pinecone or Qdrant (configurable) |
| Database | PostgreSQL + SQLAlchemy + Alembic |
| Rate limiting | slowapi |
| Cloud integrations | google-api-python-client, msgraph-sdk, msal |
| Agent scheduling | APScheduler |
| Testing | pytest + pytest-asyncio + httpx + moto (S3 mock) |
| Deployment | Docker → fly.io / Railway / AWS ECS |
---
## Phase 3 — New Files
| File | Purpose |
|---|---|
| `app/models.py` | Add `LocalAgentConfig`, `CloudAgentConfig`, `AgentRunLog` models |
| `app/schemas.py` | Add agent config schemas + WS agent frame types |
| `app/api/routes/agents.py` | Agent CRUD endpoints (catalog, local, cloud, runs, manual trigger) |
| `app/api/routes/agent_setup.py` | Chatbot Journey endpoints (start + message) |
| `app/api/routes/device_ws.py` | Persistent device WS endpoint (`/api/v1/ws/device`) |
| `app/api/routes/oauth.py` | OAuth authorize/callback for Gmail, Teams, Outlook |
| `app/core/agent_runner.py` | Agent run orchestration — local (WS file request) + cloud (API fetch) |
| `app/core/device_manager.py` | `DeviceConnectionManager` — tracks active Electron WS connections |
| `app/core/agent_scheduler.py` | Periodic scheduler for agent cron triggers |
| `app/integrations/gmail.py` | Gmail API client (fetch messages with filters) |
| `app/integrations/ms_graph.py` | MS Graph client for Outlook emails + Teams messages |
| `app/integrations/__init__.py` | Provider factory |
> **Full Phase 3 step-by-step plan:** See `AI_REFACTOR_PLAN.md` Phase 3 section.
---
## Development Rules
1. **NEVER persist user data in plaintext.** The DB stores only auth, billing, storage metadata, and marketplace data. User context arrives in requests and is discarded. Cloud blobs are E2E encrypted client-side — backend only stores opaque bytes.
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. In plan mode, `prompt_template` fields are reference IDs only.
3. **NEVER decrypt user blobs.** `app/storage/encryption.py` only verifies checksums. No decryption key ever reaches the backend.
4. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
5. **Type hints everywhere.** All functions have full type annotations.
6. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
7. **Structured logging.** JSON logs with request ID correlation.
8. **Tier gates are enforced server-side.** Never trust client-reported tier. Always fetch from DB via `TierManager.get_tier(user_id)`.
9. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.

View File

@@ -1,353 +0,0 @@
# V3 Migration Plan — Multi-Agent AI Productivity App
> Incremental migration from current architecture to v3.
> Each step is self-contained, testable, and backwards-compatible.
> No BYOK — server manages all LLM keys.
> Memory encryption: server-side per-user Fernet key (Option A).
---
## General Rules
**Code Cleanup**: As you implement each step, remove any code that becomes unused or obsolete. This includes:
- Old functions/methods that are superseded by new ones
- Deprecated imports or modules
- Dead code paths
- Old test files no longer needed
This keeps the codebase clean and prevents confusion. When removing code, note it in the commit message if significant.
---
## Decisions Log
| Topic | Decision |
|---|---|
| WS topology | Single multiplexed socket (merge chat into device WS) |
| LLM keys | Server-managed only, no user key passthrough |
| Memory encryption | Per-user server-generated Fernet key, encrypted at rest, decrypted in-memory |
| device_manager | Already multi-user correct (keyed by user_id), no structural change |
---
## Step 1 — WS Frame Protocol (schemas.py)
**Goal**: Define the v3 frame vocabulary so all subsequent steps can import it.
**Changes**:
- `app/schemas.py` — Add to `WsFrameType` enum:
- `home_request`, `floating_request`
- `stream_start`, `stream_text`, `stream_block`, `stream_end`
- `floating_domain`
- `data_request`, `data_response`, `mutation`
- Add Pydantic models:
- `WsHomeRequest(type, message, conversation_history?)`
- `WsFloatingRequest(type, message, scope: {type, id?})`
- `WsStreamStart(type, request_id)`
- `WsStreamText(type, request_id, chunk)`
- `WsStreamBlock(type, request_id, block_type, data)`
- `WsStreamEnd(type, request_id, mutations?)`
- `WsFloatingDomain(type, request_id, domain)`
- Keep all existing frame types (backward compat).
**Files touched**: `app/schemas.py`
**Test**: Unit test that validates each new model serializes/deserializes correctly.
```
pytest tests/test_schemas_v3.py
```
**Status**:
- [x] Step 1 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-1: add v3 ws frame protocol (schemas.py)"
```
---
## Step 2 — Agent Streaming + Tool Result Capture (agent_registry.py, agents/)
**Goal**: Agents can stream LLM tokens and expose structured tool results.
**Changes**:
- `app/core/agent_registry.py`:
- Add `_tool_loop_stream()` to `ChatAgent` — same logic as `_tool_loop()` but the **final** LLM call (when no more tool calls) uses `llm.astream()` and yields tokens.
- Add `self.tool_results: list[dict]` attribute to `ChatAgent.__init__()`.
- In both `_tool_loop` and `_tool_loop_stream`, capture raw `execute_on_client` results when tools run (store in `self.tool_results`).
- `app/agents/*.py` — Each agent's tools already return text summaries. No change to tools. The raw data capture happens at the `_tool_loop` level by intercepting `ToolMessage` content that comes from `execute_on_client`.
**Files touched**: `app/core/agent_registry.py`
**Test**: Unit test with mocked LLM that verifies `_tool_loop_stream()` yields tokens and `agent.tool_results` contains structured data after a tool call.
```
pytest tests/test_agent_streaming.py
```
**Status**:
- [x] Step 2 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-2: add agent streaming and tool result capture (agent_registry.py)"
```
---
## Step 3 — Router Refactor (orchestrator.py)
**Goal**: Orchestrator returns agent name alongside execution, supports streaming.
**Changes**:
- `app/core/orchestrator.py`:
- Add `orchestrate_v3(user_id, message, context, mode)` that:
1. Calls `classify_intent()` (unchanged) -> `agent_name`
2. Instantiates agent via registry
3. Returns `(agent_name, agent_instance)` — caller drives execution
- Add `orchestrate_v3_stream(user_id, message, context)` -> `AsyncGenerator` that:
1. Calls `classify_intent()` -> `agent_name`
2. Calls `agent.handle_stream()` (uses `_tool_loop_stream`)
3. Yields `(agent_name, token)` tuples — first yield includes agent name for domain detection
- Keep `orchestrate()` and `orchestrate_stream()` unchanged (backward compat for POST /chat).
**Files touched**: `app/core/orchestrator.py`
**Test**: Unit test with mocked LLM and mocked registry that verifies `orchestrate_v3_stream` yields `(agent_name, token)` pairs.
```
pytest tests/test_orchestrator_v3.py
```
**Status**:
- [x] Step 3 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-3: add router refactor with streaming support (orchestrator.py)"
```
---
## Step 4 — Output Formatting Layer (NEW: output_formatter.py)
**Goal**: Home and Floating responses diverge at this layer only.
### Block Types (from Electron app components)
The LLM outputs a JSON block stream. Each block has a `type` field that maps to
an Electron renderer component. The server validates and forwards these blocks.
**Text block** — streamed immediately, word-by-word:
```json
{ "type": "text", "content": "Here's your task summary..." }
```
**Chart blocks** — buffered until complete, validated, sent as `stream_block`.
Chart types match shadcn/ui Recharts wrappers used in the Electron app:
```json
{ "type": "chart", "chartType": "<type>", "title": "...", "data": [...], "config": {...} }
```
Supported `chartType` values:
- `area` — Area chart (shadcn AreaChart)
- `bar` — Bar chart (shadcn BarChart)
- `line` — Line chart (shadcn LineChart)
- `pie` — Pie chart (shadcn PieChart)
- `radar` — Radar chart (shadcn RadarChart)
- `radial` — Radial/gauge chart (shadcn RadialChart)
`data` is an array of objects with keys matching the chart's dataKey config.
`config` follows the shadcn ChartConfig format: `{ [dataKey]: { label, color } }`.
**Entity blocks** — server serializes from `agent.tool_results` (not LLM-generated data):
```json
{ "type": "entity_ref", "entity": "task" }
```
The server resolves this by looking up the structured data from the agent's
tool call results and emitting a `stream_block` with the full entity data.
Supported entity types (matching Electron component types):
- `task` — TaskRow component (`TaskItem`: id, title, status, priority, assignee, dueDate, projectId, ...)
- `project` — Project card (id, name, clientId, status)
- `note` — Note card (id, title, createdAt, projectId)
- `timeline` — Timeline card (GanttTimeline: id, title, date, projectId, isAiSuggested, isApproved)
**Table block** — buffered, validated:
```json
{ "type": "table", "headers": ["Col1", "Col2"], "rows": [["val1", "val2"]] }
```
**Timeline block** — buffered, validated (renders via GanttChart component):
```json
{ "type": "timeline", "timelines": [{ "id": "...", "title": "...", "date": 1234567890 }] }
```
### Changes
- `app/core/output_formatter.py` (new file):
- `HomeFormatter`:
- Receives token stream from orchestrator
- Accumulates tokens into a JSON-aware buffer
- Detects block boundaries by `type` field:
- `text` -> yields `WsStreamText` immediately (streams content word-by-word)
- `chart` -> buffers until JSON complete, validates `chartType` against allowed set, yields `WsStreamBlock`
- `entity_ref` -> looks up data from `agent.tool_results`, serializes full entity, yields `WsStreamBlock`
- `table` -> buffers, validates headers/rows structure, yields `WsStreamBlock`
- `timeline` -> buffers, validates timeline objects, yields `WsStreamBlock`
- Invalid blocks are logged and skipped (never crash the stream)
- `FloatingFormatter`:
- Receives `agent_name` from orchestrator
- Maps agent name to domain (deterministic, by code — no LLM):
- `task_agent` -> `"tasks"`
- `timeline_agent` -> `"timelines"`
- `note_agent` -> `"notes"`
- `project_agent` -> `"projects"`
- Yields `WsFloatingDomain` immediately
- Then yields `WsStreamText` for all tokens (text-only, no blocks)
**Files touched**: `app/core/output_formatter.py` (new)
**Test**: Unit test that feeds a mock token stream through each formatter and asserts correct frame output sequence.
```
pytest tests/test_output_formatter.py
```
**Status**:
- [x] Step 4 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-4: add output formatting layer (output_formatter.py)"
```
---
## Step 5 — Unified WS Handler (device_ws.py, chat.py, main.py)
**Goal**: Single multiplexed WebSocket handles device frames + Home/Floating chat.
**Changes**:
- `app/api/routes/device_ws.py`:
- Extend `_message_loop` dispatch to handle `home_request` and `floating_request`:
- On `home_request`: set `ws_context` executor, call `orchestrate_v3_stream`, pipe through `HomeFormatter`, send frames back on same socket.
- On `floating_request`: same, but pipe through `FloatingFormatter`.
- Wrap both in try/finally to clear `ws_context`.
- Each request gets a `request_id` (UUID) for frame correlation.
- Concurrent requests from same client are supported (each runs as an async task).
- `app/api/routes/chat.py`:
- Remove `chat_stream` WS endpoint and any related helper functions that were only used by it.
- Keep `POST /chat` endpoint unchanged (REST fallback).
- Clean up any unused imports.
- `app/main.py`:
- No change needed (device_ws router already registered).
**Files touched**: `app/api/routes/device_ws.py`, `app/api/routes/chat.py`, `app/main.py`
**Test**: Integration test with a WebSocket test client that:
1. Connects to `/api/v1/ws/device`
2. Sends `device_hello`
3. Sends `home_request` -> receives `stream_start`, `stream_text`*, `stream_end`
4. Sends `floating_request` -> receives `floating_domain`, `stream_text`*, `stream_end`
5. Verifies `tool_call`/`tool_result` round-trip still works during chat
```
pytest tests/test_ws_unified.py
```
**Status**:
- [x] Step 5 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-5: unify ws handler (device_ws.py, chat.py)"
```
---
## Step 6 — Memory Models + Migration (models.py, alembic)
**Goal**: Database tables for 4-tier memory, with per-user encryption key.
**Changes**:
- `app/models.py`:
- Add `encryption_key` column to `User` model (Fernet key, generated on registration).
- Add `MemoryCore` model: `id, user_id, key, value_encrypted, updated_at`
- Add `MemoryAssociative` model: `id, user_id, content_encrypted, embedding (Vector(1536)), entity_type, entity_id, updated_at`
- Add `MemoryEpisodic` model: `id, user_id, summary_encrypted, session_id, created_at`
- Add `MemoryProactive` model: `id, user_id, pattern_encrypted, confidence, source, created_at`
- `alembic/versions/` — New migration adding the 4 memory tables + user encryption_key column.
- `app/api/routes/auth.py` — On user registration, generate and store a Fernet key.
**Files touched**: `app/models.py`, `alembic/versions/xxx_add_memory_tables.py`, `app/api/routes/auth.py`
**Test**: Run migration up/down, verify tables exist with correct columns.
```
alembic upgrade head && alembic downgrade -1 && alembic upgrade head
pytest tests/test_memory_models.py
```
**Status**:
- [x] Step 6 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-6: add memory models and migration (models.py, alembic)"
```
---
## Step 7 — Memory Middleware (NEW: memory_middleware.py)
**Goal**: Enrich every Router call with memory context, store interactions after.
**Changes**:
- `app/core/memory_middleware.py` (new file):
- `MemoryMiddleware` class with:
- `enrich_context(user_id, message) -> dict` (pre-LLM):
1. Load core memory (user prefs) — always injected
2. Embed `message`, search `MemoryAssociative` via pgvector — top-k relevant
3. Fetch recent `MemoryEpisodic` entries — last N sessions
4. Fetch active `MemoryProactive` patterns — above confidence threshold
5. Return merged context dict
- `store_episode(user_id, session_id, message, response)` (post-LLM):
1. Summarize interaction (short LLM call or heuristic)
2. Encrypt and store in `MemoryEpisodic`
3. Embed interaction, encrypt and upsert in `MemoryAssociative`
- `update_core(user_id, key, value)` — explicit preference update
- All read/write operations encrypt/decrypt using the user's Fernet key from `User.encryption_key`
- `app/api/routes/device_ws.py` — Update `home_request` and `floating_request` handlers:
- Before orchestrator: `enriched = await memory.enrich_context(user_id, message)`
- After response complete: `await memory.store_episode(user_id, ...)`
**Files touched**: `app/core/memory_middleware.py` (new), `app/api/routes/device_ws.py`
**Test**: Unit test with seeded memory rows that verifies:
1. `enrich_context` returns core prefs + associative matches + episodic summaries
2. `store_episode` creates encrypted rows that can be decrypted with the user's key
3. End-to-end WS test: send `home_request`, verify memory enrichment is passed to orchestrator
```
pytest tests/test_memory_middleware.py
```
**Status**:
- [x] Step 7 complete
**Commit**: After tests pass, commit with:
```
git commit -m "step-7: add memory middleware (memory_middleware.py, device_ws.py)"
```
---
## Summary
| Step | Component | Effort | Depends On |
|------|-----------|--------|------------|
| 1 | WS Frame Protocol | Low | — |
| 2 | Agent Streaming | Medium | Step 1 |
| 3 | Router Refactor | Medium | Step 2 |
| 4 | Output Formatter | High | Steps 1, 3 |
| 5 | Unified WS Handler | High | Steps 14 |
| 6 | Memory Models | Medium | — |
| 7 | Memory Middleware | High | Steps 5, 6 |
Steps 15 form the streaming pipeline. Steps 67 form the memory system.
Step 6 can run in parallel with Steps 24 (no dependencies).

View File

@@ -1,4 +1,4 @@
"""Import all agent modules to trigger @registry.register decorators."""
"""Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
from app.agents import timeline_agent, note_agent, project_agent, task_agent

View File

@@ -1,31 +1,14 @@
"""Note agent — Markdown note management (list, get, create, update, delete)."""
"""Note agent — tool definitions for Markdown note CRUD."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import embed, get_llm
from app.core.llm import embed
from app.core.ws_context import execute_on_client
_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
" - content is always Markdown; preserve formatting when updating\n"
" - project_id is optional; link a note to a project when mentioned\n"
" - When updating, call get_note first if you need to read existing content\n"
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@tool
async def list_notes(project_id: str = "") -> str:
@@ -122,23 +105,4 @@ async def delete_note(note_id: str) -> str:
return f"Note {note_id} deleted."
@registry.register
class NoteAgent(ChatAgent):
def get_name(self) -> str:
return "note_agent"
def get_description(self) -> str:
return "Manages notes: list, get, create, update, delete"
def get_tools(self) -> list[Any]:
return [list_notes, get_note, create_note, update_note, delete_note]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -1,33 +1,13 @@
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
"""Project agent — tool definitions for project lifecycle CRUD."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
from app.core.ws_context import execute_on_client
_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - client_id is optional; link to a client only when explicitly mentioned\n"
" - ai_summary is populated only when the user asks for a project summary;\n"
" derive it from context data — do not fabricate content\n"
" - Use list_projects for scoped queries; list_all_projects only when the\n"
" user wants a complete cross-client view including archived projects\n"
" - get_project requires a project UUID; resolve the ID first by calling\n"
" list_projects if you only have a project name\n"
" - Prefer archiving (update_project status=archived) over deletion;\n"
" only call delete_project when the user explicitly confirms deletion."
)
@tool
async def list_projects(
@@ -137,30 +117,4 @@ async def delete_project(project_id: str) -> str:
return f"Project {project_id} permanently deleted."
@registry.register
class ProjectAgent(ChatAgent):
def get_name(self) -> str:
return "project_agent"
def get_description(self) -> str:
return "Manages projects: list, get, create, update, archive, delete"
def get_tools(self) -> list[Any]:
return [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -1,35 +1,14 @@
"""Task agent — full CRUD for tasks and task comments."""
"""Task agent — tool definitions for task and task comment CRUD."""
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
from app.core.ws_context import execute_on_client
_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n"
" - status must be one of: todo, in_progress, done\n"
" - priority must be one of: high, medium, low\n"
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
" - project_id is optional; link to a project when the user mentions one\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
" did not explicitly request; 0 otherwise\n"
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
" - Use list_tasks_due_today for 'what's due today' queries\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
)
# ── Task tools ────────────────────────────────────────────────────────
@@ -220,35 +199,4 @@ async def delete_task_comment(comment_id: str) -> str:
return f"Comment {comment_id} deleted."
# ── Agent ─────────────────────────────────────────────────────────────
@registry.register
class TaskAgent(ChatAgent):
def get_name(self) -> str:
return "task_agent"
def get_description(self) -> str:
return "Manages tasks and comments: list, create, update, delete, due-today, comments"
def get_tools(self) -> list[Any]:
return [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -1,30 +1,13 @@
"""Timeline agent — project milestone management (list, create, update, delete)."""
"""Timeline agent — tool definitions for project milestone CRUD."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
from app.core.ws_context import execute_on_client
_SYSTEM_PROMPT = (
"You are a project timeline assistant. Timelines are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - is_approved: 0 until the user explicitly confirms; then 1\n"
" - For update_timeline, use -1 for integer fields you do not want to change\n"
" - Listing without a project_id returns all timelines across projects\n"
" - Always echo the title and formatted date in your confirmation."
)
@tool
async def list_timelines(project_id: str = "") -> str:
@@ -106,23 +89,4 @@ async def delete_timeline(timeline_id: str) -> str:
return f"Timeline {timeline_id} deleted."
@registry.register
class TimelineAgent(ChatAgent):
def get_name(self) -> str:
return "timeline_agent"
def get_description(self) -> str:
return "Manages project timelines (milestones): list, create, update, delete"
def get_tools(self) -> list[Any]:
return [list_timelines, create_timeline, update_timeline, delete_timeline]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -9,8 +9,10 @@ from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from app.api.deps import get_current_user
from app.core.orchestrator import orchestrate
from app.schemas import ChatRequest, UserProfile
from app.core.deep_agent import run_home
from app.core.memory_middleware import MemoryMiddleware
from app.db import async_session
from app.schemas import ChatRequest, ChatResponse, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"])
@@ -20,10 +22,21 @@ async def chat(
body: ChatRequest,
current_user: UserProfile = Depends(get_current_user),
) -> JSONResponse:
"""Route a chat message through the orchestrator.
"""Route a chat message through the Home deep agent (non-streaming)."""
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(current_user.id, body.message)
Returns ``ChatResponse`` for ``execution_mode='direct'``,
or ``ExecutionPlan`` for ``execution_mode='plan'``.
"""
result = await orchestrate(body)
context = {
**body.context.model_dump(),
**memory_context,
}
response_text = await run_home(
user_id=current_user.id,
message=body.message,
context=context,
db_session_factory=async_session,
)
result = ChatResponse(response=response_text)
return JSONResponse(content=result.model_dump())

View File

@@ -43,7 +43,7 @@ from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs
from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware
from app.core.orchestrator import orchestrate_v3_stream
from app.core.deep_agent import run_home_stream, run_floating_stream
from app.core.output_formatter import HomeFormatter, FloatingFormatter
from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session
@@ -200,13 +200,35 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
# ── v3 Chat Handlers ──────────────────────────────────────────────────
_WS_TOOL_CALL_TIMEOUT = 30 # seconds to wait for Electron tool_result
async def _make_ws_executor(websocket: WebSocket, user_id: str):
"""Return a callback that sends tool_call frames and awaits tool_result."""
async def _executor(payload: dict) -> dict:
payload["type"] = WsFrameType.tool_call
call_id = payload["id"]
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
await websocket.send_text(json.dumps(payload))
future = device_manager.create_pending_call(user_id, payload["id"])
return await future
future = device_manager.create_pending_call(user_id, call_id)
try:
result = await asyncio.wait_for(future, timeout=_WS_TOOL_CALL_TIMEOUT)
except asyncio.TimeoutError:
logger.error(
"ws_executor: timeout waiting for tool_result id=%s action=%s user=%s",
call_id, payload.get("action"), user_id,
)
# Clean up the pending future so it doesn't leak
conn = device_manager._connections.get(user_id)
if conn:
conn.pending_calls.pop(call_id, None)
return {"error": f"Tool call timed out after {_WS_TOOL_CALL_TIMEOUT}s", "rows": []}
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
call_id, type(result).__name__,
list(result.keys()) if isinstance(result, dict) else "N/A")
if result is None:
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
return result
return _executor
@@ -233,21 +255,13 @@ async def _handle_home_request(
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
agent_holder: list = []
try:
token_stream = orchestrate_v3_stream(
user_id, message, context, agent_holder=agent_holder
event_stream = run_home_stream(
user_id, message, context, db_session_factory=async_session
)
formatter = HomeFormatter(request_id=request_id, tool_results=[])
async for ws_frame in formatter.format(token_stream):
# Inject mutations from agent tool_results into stream_end
if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
ws_frame.mutations = [ # type: ignore[union-attr]
{"action": r["action"], "table": r["table"], "data": r["data"]}
for r in getattr(agent_holder[0], "tool_results", [])
]
formatter = HomeFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
# Collect text chunks to build the full response for episode storage
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
@@ -287,18 +301,13 @@ async def _handle_floating_request(
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
agent_holder: list = []
try:
token_stream = orchestrate_v3_stream(
user_id, message, context, agent_holder=agent_holder
event_stream = run_floating_stream(
user_id, message, context, scope=scope,
db_session_factory=async_session,
)
formatter = FloatingFormatter(request_id=request_id)
async for ws_frame in formatter.format(token_stream):
if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
ws_frame.mutations = [ # type: ignore[union-attr]
{"action": r["action"], "table": r["table"], "data": r["data"]}
for r in getattr(agent_holder[0], "tool_results", [])
]
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]

View File

@@ -1,37 +0,0 @@
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
from app.api.deps import get_current_user
from app.core.execution_plan import plan_cache
from app.schemas import ExecutionPlan, UserProfile
router = APIRouter(prefix="/plans", tags=["plans"])
@router.get("/playbook", response_model=list[ExecutionPlan])
async def list_playbooks(
current_user: UserProfile = Depends(get_current_user),
) -> list[ExecutionPlan]:
"""Return all cached execution plan playbooks for the authenticated user.
TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
"""
return plan_cache.get_all_playbooks()
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
async def get_playbook(
plan_id: str,
current_user: UserProfile = Depends(get_current_user),
) -> ExecutionPlan:
"""Return a specific execution plan playbook by ID."""
plan = plan_cache.get_plan(plan_id)
if plan is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Plan not found: {plan_id}",
)
return plan

View File

@@ -1,217 +0,0 @@
"""Agent Registry — base classes and singleton registry for chat agents."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any
class BaseAgent(ABC):
"""Common base for all agents."""
def __init__(
self,
user_id: str = "",
shared_memory: dict[str, Any] | None = None,
vector_store_context: list[str] | None = None,
) -> None:
self.user_id = user_id
self.shared_memory: dict[str, Any] = shared_memory or {}
self.vector_store_context: list[str] = vector_store_context or []
@abstractmethod
def get_name(self) -> str: ...
@abstractmethod
def get_description(self) -> str: ...
@property
def skills(self) -> list[str]:
"""Override in subclasses to advertise capabilities."""
return []
class ChatAgent(BaseAgent):
"""Base class for LLM-powered chat agents."""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
# Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
self.tool_results: list[dict] = []
@abstractmethod
async def handle(self, query: str, context: dict[str, Any]) -> str:
"""Process a user query and return a text response."""
...
async def handle_stream(
self, query: str, context: dict[str, Any]
) -> AsyncGenerator[str, None]:
"""Streaming variant of handle().
Default: calls handle() and yields the full response as one chunk.
Override in subclasses for true token-level streaming via _tool_loop_stream.
"""
yield await self.handle(query, context)
@abstractmethod
def get_tools(self) -> list[Any]:
"""Return LangChain tool definitions available to this agent."""
...
async def _tool_loop(
self,
llm: Any,
messages: list[Any],
tools: list[Any],
max_iter: int = 5,
) -> str:
"""Shared tool-calling loop.
Binds *tools* to *llm*, invokes iteratively until the model stops
requesting tool calls or *max_iter* is reached, and returns the
final text response. Captures raw execute_on_client results in
``self.tool_results``.
"""
from langchain_core.messages import AIMessage, ToolMessage
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
collector: list[dict] = []
set_tool_result_collector(collector)
try:
llm_with_tools = llm.bind_tools(tools) if tools else llm
for _ in range(max_iter):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return str(response.content)
# Execute each requested tool call
tool_map = {t.name: t for t in tools}
for call in response.tool_calls:
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
result = f"Unknown tool: {call['name']}"
else:
result = await tool_fn.ainvoke(call["args"])
messages.append(
ToolMessage(content=str(result), tool_call_id=call["id"])
)
# Exhausted iterations — ask model for a final answer without tools
response = await llm.ainvoke(messages)
return str(response.content)
finally:
clear_tool_result_collector()
self.tool_results = collector
async def _tool_loop_stream(
self,
llm: Any,
messages: list[Any],
tools: list[Any],
max_iter: int = 5,
) -> AsyncGenerator[str, None]:
"""Streaming variant of ``_tool_loop``.
Behaves identically for tool-calling iterations (uses ainvoke to parse
tool calls). For the final response — when the model produces no further
tool calls — switches to ``llm.astream()`` and yields text tokens.
Captures raw execute_on_client results in ``self.tool_results``.
"""
from langchain_core.messages import AIMessage, ToolMessage
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
collector: list[dict] = []
set_tool_result_collector(collector)
try:
llm_with_tools = llm.bind_tools(tools) if tools else llm
for _ in range(max_iter):
response: AIMessage = await llm_with_tools.ainvoke(messages)
if not response.tool_calls:
# Stream the final answer — don't keep the ainvoke result.
async for chunk in llm.astream(messages):
if chunk.content:
yield str(chunk.content)
return
messages.append(response)
# Execute each requested tool call
tool_map = {t.name: t for t in tools}
for call in response.tool_calls:
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
result = f"Unknown tool: {call['name']}"
else:
result = await tool_fn.ainvoke(call["args"])
messages.append(
ToolMessage(content=str(result), tool_call_id=call["id"])
)
# Exhausted iterations — stream a final answer without tools
async for chunk in llm.astream(messages):
if chunk.content:
yield str(chunk.content)
finally:
clear_tool_result_collector()
self.tool_results = collector
class AgentRegistry:
"""Singleton registry for ChatAgent subclasses."""
_instance: AgentRegistry | None = None
def __init__(self) -> None:
self._agents: dict[str, type[ChatAgent]] = {}
def __new__(cls) -> AgentRegistry:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._agents = {}
return cls._instance
# ── public API ───────────────────────────────────────────────────
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
"""Class decorator — registers an agent by its name."""
instance = agent_class()
name = instance.get_name()
self._agents[name] = agent_class
return agent_class
def get(self, name: str) -> ChatAgent:
"""Return a fresh instance of the named agent."""
cls = self._agents.get(name)
if cls is None:
raise KeyError(f"Agent not found: {name}")
return cls()
def list_agents(self) -> list[dict[str, str]]:
"""Return ``[{name, description}]`` for the orchestrator prompt."""
result: list[dict[str, str]] = []
for cls in self._agents.values():
inst = cls()
result.append(
{"name": inst.get_name(), "description": inst.get_description()}
)
return result
async def call_agent(
self, name: str, query: str, context: dict[str, Any]
) -> str:
"""Instantiate the named agent and call its ``handle`` method."""
agent = self.get(name)
return await agent.handle(query, context)
# Module-level singleton
registry = AgentRegistry()

View File

@@ -1,4 +1,4 @@
"""Agent run orchestrator.
"""Agent run manager.
Drives two agent types:

489
app/core/deep_agent.py Normal file
View File

@@ -0,0 +1,489 @@
"""Deep Agent — ``create_deep_agent`` supervisors for home and floating modes.
Two supervisor graphs (via ``deepagents.create_deep_agent``):
* **HomeSupervisor** — gathers data from multiple domains, presents
structured overview with entity/chart tags.
* **FloatingSupervisor** — focused, scoped assistant for a single entity/domain.
Each supervisor delegates to four sub-agents (task, project, note, timeline)
via the built-in ``task`` tool provided by ``SubAgentMiddleware``.
The sub-agents talk to Electron via ``execute_on_client``.
Built-in middleware provides: todo-list tracking, virtual filesystem,
automatic context summarisation, prompt-caching, and tool-call patching.
Streaming uses ``astream(stream_mode=["messages", "updates"])`` so that
callers can sniff:
* ``("messages", (token, metadata))`` — text tokens for streaming
* ``("updates", ...)`` — tool call results for mutations
An ``update_core_memory`` tool is available to both supervisors for
persisting user preferences mid-conversation (MemGPT-style).
"""
from __future__ import annotations
import json
import logging
from typing import Any, AsyncGenerator
from deepagents import create_deep_agent
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
from langchain_core.tools import tool
from app.core.llm import get_llm
from app.core.ws_context import (
clear_tool_result_collector,
set_tool_result_collector,
)
logger = logging.getLogger(__name__)
# ── Sub-agent tool imports ────────────────────────────────────────────
from app.agents.task_agent import ( # noqa: E402
add_task_comment,
create_task,
delete_task,
delete_task_comment,
list_task_comments,
list_tasks,
list_tasks_due_today,
update_task,
)
from app.agents.note_agent import ( # noqa: E402
create_note,
delete_note,
get_note,
list_notes,
update_note,
)
from app.agents.project_agent import ( # noqa: E402
create_project,
delete_project,
get_project,
list_all_projects,
list_projects,
update_project,
)
from app.agents.timeline_agent import ( # noqa: E402
create_timeline,
delete_timeline,
list_timelines,
update_timeline,
)
# ── Sub-agent definitions ─────────────────────────────────────────────
_TASK_TOOLS = [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]
_NOTE_TOOLS = [list_notes, get_note, create_note, update_note, delete_note]
_PROJECT_TOOLS = [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]
_TIMELINE_TOOLS = [list_timelines, create_timeline, update_timeline, delete_timeline]
def _make_subagent_specs() -> list[dict[str, Any]]:
"""Return SubAgent dicts for the four workspace domains.
Each dict follows the ``deepagents`` ``SubAgent`` TypedDict:
name, description, system_prompt, tools, model
The model and middleware are filled in by ``create_deep_agent`` automatically.
"""
llm = get_llm()
return [
{
"name": "task_agent",
"description": (
"Manages tasks and comments: list, create, update, delete, "
"due-today, and comments. Use when the user asks about tasks, "
"to-dos, assignments, deadlines, or anything task-related."
),
"system_prompt": (
"You are a task management assistant. You create, update, list, "
"and track tasks and their comments.\n\n"
"Rules:\n"
" - status must be one of: todo, in_progress, done\n"
" - priority must be one of: high, medium, low\n"
" - due_date is a Unix timestamp in milliseconds\n"
" - assignees is a JSON-encoded array of strings\n"
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
),
"tools": _TASK_TOOLS
},
{
"name": "note_agent",
"description": (
"Manages notes: list, get, create, update, delete. "
"Use when the user asks about notes, documents, or written content."
),
"system_prompt": (
"You are a note-taking assistant. You help users create, retrieve, "
"update, and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
" - content is always Markdown; preserve formatting when updating\n"
" - When updating, call get_note first if you need to read existing "
"content before appending or replacing sections\n"
" - Do not fabricate note content."
),
"tools": _NOTE_TOOLS
},
{
"name": "project_agent",
"description": (
"Manages projects: list, get, create, update, archive, delete. "
"Use when the user asks about projects, workspaces, or project status."
),
"system_prompt": (
"You are a project management assistant. You help users create, "
"find, update, and archive projects.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - Prefer archiving over deletion\n"
" - ai_summary is populated only when the user asks for a summary."
),
"tools": _PROJECT_TOOLS
},
{
"name": "timeline_agent",
"description": (
"Manages project timelines and milestones: list, create, update, "
"delete. Use when the user asks about timelines, milestones, "
"deadlines, or project scheduling."
),
"system_prompt": (
"You are a project timeline assistant. Timelines are milestone "
"dates that track progress on a project.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create\n"
" - date is a Unix timestamp in milliseconds\n"
" - For update_timeline, use -1 for integer fields you do not "
"want to change."
),
"tools": _TIMELINE_TOOLS
},
]
# ── Update core memory tool ──────────────────────────────────────────
def _make_update_core_memory_tool(user_id: str, db_session_factory):
"""Create a tool that persists a key/value preference in core memory."""
@tool
async def update_core_memory(key: str, value: str) -> str:
"""Save a user preference or fact to long-term core memory.
key: short label for the memory (e.g. 'preferred_language', 'timezone')
value: the value to remember
Use this when the user states a preference or fact worth remembering.
"""
from app.core.memory_middleware import MemoryMiddleware
async with db_session_factory() as db:
memory = MemoryMiddleware(db)
await memory.update_core(user_id, key, value)
return f"Remembered: {key} = {value}"
return update_core_memory
# ── System prompts ────────────────────────────────────────────────────
_HOME_SYSTEM = (
"You are Adiuva, a smart workspace assistant on the Home dashboard.\n"
"Your job is to help the user by gathering data from their workspace and "
"presenting a comprehensive overview.\n\n"
"You have sub-agents (task_agent, note_agent, project_agent, "
"timeline_agent) accessible via the `task` tool. Delegate to "
"the appropriate sub-agent(s) based on the user's request. You can call "
"multiple sub-agents in parallel if needed.\n\n"
"You also have an update_core_memory tool — use it when the user states "
"a preference or important fact worth remembering long-term.\n\n"
"IMPORTANT: You do NOT have direct access to workspace data. Always "
"delegate to your subagents using the task() tool. Do not attempt to "
"answer workspace queries yourself — the subagents have the tools to "
"fetch and modify data. You can call multiple subagents in parallel "
"when the request spans multiple domains.\n\n"
"## Entity References\n"
"When your response mentions specific workspace entities, embed them "
"inline using entity tags so the UI can render interactive components.\n"
"Format: <type>[comma-separated UUIDs]</type>\n"
"Supported types: task, project, note, timeline\n\n"
"Example response:\n"
" Here is your project:\n"
" <project>[abc-123-def]</project>\n"
" It has these pending tasks:\n"
" <task>[def-456,ghi-789]</task>\n\n"
"IMPORTANT: Only include IDs of entities that are directly relevant to "
"the user's question. Do NOT dump all entity IDs returned by a tool — "
"filter to only the ones the user asked about or that matter for the answer.\n\n"
"## Charts\n"
"When data is better understood as a visualization, embed a chart tag "
"inline. The frontend renders it using shadcn/ui Recharts components.\n"
"Format: <chart>{{JSON}}</chart>\n\n"
"JSON shape:\n"
' {{"chartType":"<type>","title":"...","data":[...],"config":{{...}}}}\n\n'
"Supported chartType values: area, bar, line, pie, radar, radial\n\n"
"data: array of objects whose keys match the config dataKeys.\n"
"config: {{ dataKey: {{ label, color }} }} — follows shadcn ChartConfig.\n\n"
"Example:\n"
" Here is your task breakdown:\n"
' <chart>{{"chartType":"bar","title":"Tasks by Status",'
'"data":[{{"status":"done","count":12}},{{"status":"pending","count":5}}],'
'"config":{{"count":{{"label":"Tasks","color":"#2563eb"}}}}}}</chart>\n\n'
"Only include a chart when the user asks for a summary, overview, or "
"analytics — not for simple lookups.\n\n"
"Memory context:\n{memory_context}"
)
_FLOATING_SYSTEM = (
"You are Adiuva, a focused workspace assistant in the floating panel.\n"
"The user is currently working in the '{scope_type}' section"
"{scope_detail}.\n\n"
"You have sub-agents (task_agent, note_agent, project_agent, "
"timeline_agent) accessible via the `task` tool. Focus your "
"help on the user's current scope, but you can use other sub-agents "
"if the request requires it.\n\n"
"You also have an update_core_memory tool — use it when the user states "
"a preference or important fact worth remembering long-term.\n\n"
"IMPORTANT: You do NOT have direct access to workspace data. Always "
"delegate to your subagents using the task() tool. Do not attempt to "
"answer workspace queries yourself — the subagents have the tools to "
"fetch and modify data.\n\n"
"Provide direct, conversational responses.\n\n"
"Memory context:\n{memory_context}"
)
def _format_memory_context(memory: dict[str, Any]) -> str:
"""Format the memory dict into a readable string for the system prompt."""
if not memory:
return "(no memory available)"
parts = []
if memory.get("core_memory"):
parts.append("Preferences: " + json.dumps(memory["core_memory"]))
if memory.get("associative_memory"):
parts.append("Related memories: " + "; ".join(memory["associative_memory"][:3]))
if memory.get("episodic_memory"):
parts.append("Recent sessions: " + "; ".join(memory["episodic_memory"][:3]))
if memory.get("proactive_hints"):
parts.append("Patterns: " + "; ".join(memory["proactive_hints"][:3]))
return "\n".join(parts) if parts else "(no memory available)"
# ── Graph builders ────────────────────────────────────────────────────
def build_home_graph(
user_id: str,
memory_context: dict[str, Any],
db_session_factory,
):
"""Build the Home supervisor graph."""
subagent_specs = _make_subagent_specs()
memory_tool = _make_update_core_memory_tool(user_id, db_session_factory)
prompt = _HOME_SYSTEM.format(
memory_context=_format_memory_context(memory_context),
)
return create_deep_agent(
model=get_llm(),
tools=[memory_tool],
system_prompt=prompt,
subagents=subagent_specs,
name="home_supervisor",
)
def build_floating_graph(
user_id: str,
memory_context: dict[str, Any],
scope: dict[str, Any],
db_session_factory,
):
"""Build the Floating supervisor graph."""
subagent_specs = _make_subagent_specs()
memory_tool = _make_update_core_memory_tool(user_id, db_session_factory)
scope_type = scope.get("type", "general")
scope_id = scope.get("id")
scope_detail = f" (id: {scope_id})" if scope_id else ""
prompt = _FLOATING_SYSTEM.format(
scope_type=scope_type,
scope_detail=scope_detail,
memory_context=_format_memory_context(memory_context),
)
return create_deep_agent(
model=get_llm(),
tools=[memory_tool],
system_prompt=prompt,
subagents=subagent_specs,
name="floating_supervisor",
)
# ── Stream event type ────────────────────────────────────────────────
# Events yielded by run_*_stream:
# ("token", str) — text token for streaming
# ("tool_start", dict) — {"name": "task_agent", "args": {...}}
# ("tool_end", dict) — {"name": "task_agent", "result": "..."}
# ── Stream runners ────────────────────────────────────────────────────
async def _run_graph_stream(
graph,
message: str,
) -> AsyncGenerator[tuple[str, Any], None]:
"""Run a supervisor graph with streaming, yielding event tuples.
Uses ``stream_mode=["messages", "updates"]`` to get both token-level
streaming and update events for tool calls.
"""
inputs = {"messages": [HumanMessage(content=message)]}
collector: list[dict] = []
set_tool_result_collector(collector)
try:
async for stream_mode, chunk in graph.astream(
inputs,
stream_mode=["messages", "updates"],
):
if stream_mode == "messages":
msg, metadata = chunk
agent_name = (
metadata.get("lc_agent_name", "?")
if isinstance(metadata, dict) else "?"
)
node = (
metadata.get("langgraph_node", "?")
if isinstance(metadata, dict) else "?"
)
# Log every message event with agent attribution
if isinstance(msg, (AIMessage, AIMessageChunk)) and msg.content:
logger.info(
"[%s] %s node=%s content=%s",
agent_name,
type(msg).__name__,
node,
str(msg.content),
)
elif isinstance(msg, (AIMessage, AIMessageChunk)) and msg.tool_calls:
tool_names = [tc["name"] for tc in msg.tool_calls]
logger.info(
"[%s] %s node=%s tool_calls=%s",
agent_name,
type(msg).__name__,
node,
tool_names,
)
elif hasattr(msg, "name") and hasattr(msg, "content") and msg.content:
# ToolMessage — log tool result
logger.info(
"[%s] ToolMessage tool=%s node=%s result=%s",
agent_name,
getattr(msg, "name", "?"),
node,
str(msg.content),
)
# Only yield tokens from the supervisor's final response
# (not from sub-agent internal LLM calls).
# Accept both AIMessageChunk (streamed tokens) and AIMessage
# (full response from non-streaming providers).
# create_deep_agent names the LLM node "model".
if (
isinstance(msg, (AIMessage, AIMessageChunk))
and msg.content
and not msg.tool_calls
and isinstance(metadata, dict)
and metadata.get("langgraph_node") == "model"
):
yield ("token", str(msg.content))
elif stream_mode == "updates":
# Updates is a dict of {node_name: state_update}
if not isinstance(chunk, dict):
continue
for node_name, state_update in chunk.items():
if node_name != "tools":
continue
# Tool node executed — extract tool call results
tool_messages = state_update.get("messages", [])
for tool_msg in tool_messages:
if hasattr(tool_msg, "name") and hasattr(tool_msg, "content"):
yield (
"tool_end",
{"name": tool_msg.name, "result": str(tool_msg.content)},
)
finally:
clear_tool_result_collector()
# Yield the collected mutations so callers can attach them to stream_end
yield ("mutations", collector)
async def run_home_stream(
user_id: str,
message: str,
context: dict[str, Any],
db_session_factory,
) -> AsyncGenerator[tuple[str, Any], None]:
"""Run the Home supervisor and yield streaming events."""
graph = build_home_graph(user_id, context, db_session_factory)
async for event in _run_graph_stream(graph, message):
yield event
async def run_floating_stream(
user_id: str,
message: str,
context: dict[str, Any],
scope: dict[str, Any],
db_session_factory,
) -> AsyncGenerator[tuple[str, Any], None]:
"""Run the Floating supervisor and yield streaming events."""
graph = build_floating_graph(user_id, context, scope, db_session_factory)
async for event in _run_graph_stream(graph, message):
yield event
async def run_home(
user_id: str,
message: str,
context: dict[str, Any],
db_session_factory,
) -> str:
"""Run the Home supervisor (non-streaming) and return full response text."""
graph = build_home_graph(user_id, context, db_session_factory)
result = await graph.ainvoke(
{"messages": [HumanMessage(content=message)]}
)
messages = result["messages"]
for msg in reversed(messages):
if hasattr(msg, "content") and msg.content and not getattr(msg, "tool_calls", None):
return str(msg.content)
return ""

View File

@@ -1,222 +0,0 @@
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
from __future__ import annotations
from collections import OrderedDict
from typing import Any
from app.schemas import ExecutionPlan, PlanStep
# ── Prompt Template Registry ──────────────────────────────────────────
class PromptTemplateRegistry:
"""Server-side store mapping template IDs to prompt text.
Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
The actual prompt text is resolved here on the server, keeping prompt IP
out of API responses.
"""
def __init__(self) -> None:
self._templates: dict[str, str] = {}
def register(self, template_id: str, prompt_text: str) -> None:
self._templates[template_id] = prompt_text
def get(self, template_id: str) -> str:
"""Resolve a template ID to its prompt text.
Raises ``KeyError`` if the template is not registered.
"""
text = self._templates.get(template_id)
if text is None:
raise KeyError(f"Template not found: {template_id!r}")
return text
def has(self, template_id: str) -> bool:
return template_id in self._templates
def list_ids(self) -> list[str]:
"""Return all registered template IDs (never the text)."""
return list(self._templates.keys())
# ── Execution Plan Builder ────────────────────────────────────────────
class ExecutionPlanBuilder:
"""Fluent builder for ``ExecutionPlan`` objects.
Example::
plan = (
ExecutionPlanBuilder("task_agent")
.add_llm_step("tpl_task_agent_default", {"message": user_msg})
.add_data_step("create_record", data_from_step=0)
.build()
)
"""
def __init__(self, agent: str) -> None:
self._agent = agent
self._steps: list[PlanStep] = []
# ── step adders ──────────────────────────────────────────────────
def add_step(
self, action: str, params: dict[str, Any] | None = None
) -> ExecutionPlanBuilder:
"""Append a generic action step with optional parameters."""
self._steps.append(PlanStep(action=action, variables=params))
return self
def add_llm_step(
self, template_id: str, variables: dict[str, Any] | None = None
) -> ExecutionPlanBuilder:
"""Append an LLM step referencing a server-side template by ID."""
self._steps.append(
PlanStep(action="llm", prompt_template=template_id, variables=variables)
)
return self
def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
"""Append a step whose input comes from the output of an earlier step."""
self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
return self
# ── build ────────────────────────────────────────────────────────
def build(self) -> ExecutionPlan:
"""Validate step references and return the ``ExecutionPlan``.
Raises ``ValueError`` if any ``data_from_step`` references a
non-existent or future step index.
"""
for i, step in enumerate(self._steps):
if step.data_from_step is not None:
if not (0 <= step.data_from_step < i):
raise ValueError(
f"Step {i}: data_from_step={step.data_from_step} must "
f"reference a preceding step index in range 0..{i - 1}"
)
return ExecutionPlan(agent=self._agent, steps=list(self._steps))
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
class PlanCache:
"""In-memory LRU cache for ``ExecutionPlan`` objects.
Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
The cache also serves as a runtime memoisation layer so that repeated
identical intent classifications can skip re-building the plan.
"""
def __init__(self, maxsize: int = 1000) -> None:
self._maxsize = maxsize
self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
"""Store *plan* under *key*, evicting the LRU entry if at capacity."""
if key in self._cache:
del self._cache[key] # remove so re-insertion places it at the end
elif len(self._cache) >= self._maxsize:
self._cache.popitem(last=False) # evict least-recently-used
self._cache[key] = plan
def get_plan(self, key: str) -> ExecutionPlan | None:
"""Return the cached plan for *key*, or ``None`` if not present.
Accessing a plan marks it as most-recently used.
"""
if key not in self._cache:
return None
self._cache.move_to_end(key)
return self._cache[key]
def get_all_playbooks(self) -> list[ExecutionPlan]:
"""Return all cached plans (most-recently used last)."""
return list(self._cache.values())
# ── Module-level singletons ───────────────────────────────────────────
template_registry = PromptTemplateRegistry()
plan_cache = PlanCache()
def _register_builtin_templates() -> None:
"""Register the built-in server-side prompt templates.
These strings never leave the server. Clients only receive the IDs.
"""
_tpls: dict[str, str] = {
"tpl_task_agent_default": (
"You are a task management assistant. Help the user create, update, "
"list, and track tasks. Use correct status values (todo, in_progress, "
"done) and priority values (high, medium, low) from the workspace model."
),
"tpl_timeline_agent_default": (
"You are a project timeline assistant. Help the user create and manage "
"milestone timelines on their projects. Every timeline requires a "
"project_id and a date expressed as a Unix timestamp in milliseconds."
),
"tpl_project_agent_default": (
"You are a project management assistant. Help the user create, find, "
"update, and archive projects. Projects have a name, an optional client, "
"and a status of either active or archived."
),
"tpl_note_agent_default": (
"You are a note-taking assistant. Help the user create, retrieve, update, "
"and delete Markdown notes. Notes can optionally be linked to a project."
),
"tpl_task_extract_from_project": (
"Extract all actionable tasks from the provided project context. "
"Return a structured list of tasks, each with a title, inferred priority "
"(high, medium, or low), suggested status (todo), and a due_date in "
"milliseconds where a deadline can be inferred."
),
"tpl_note_weekly_summary": (
"Generate a weekly project summary note from the provided workspace data. "
"Include: tasks completed this week, tasks due soon, active projects, "
"and upcoming timelines. Format the output as clean Markdown."
),
}
for tid, text in _tpls.items():
template_registry.register(tid, text)
def _load_playbooks() -> None:
"""Pre-build and cache the built-in playbooks."""
playbooks: list[tuple[str, ExecutionPlan]] = [
(
"create_tasks_from_project",
ExecutionPlanBuilder("project_agent")
.add_llm_step(
"tpl_task_extract_from_project",
{"source": "project_context"},
)
.add_data_step("create_record", data_from_step=0)
.build(),
),
(
"generate_weekly_note",
ExecutionPlanBuilder("note_agent")
.add_llm_step(
"tpl_note_weekly_summary",
{"period": "last_7_days"},
)
.add_data_step("create_record", data_from_step=0)
.build(),
),
]
for key, plan in playbooks:
plan_cache.cache_plan(key, plan)
# Initialise on module load
_register_builtin_templates()
_load_playbooks()

View File

@@ -1,6 +1,6 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
instead of directly constructing a provider-specific class. The model string
follows the `LiteLLM model naming convention
<https://docs.litellm.ai/docs/providers>`_:

View File

@@ -43,7 +43,7 @@ _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
class MemoryMiddleware:
"""Enrich orchestrator context with memory and persist interactions after."""
"""Enrich agent context with memory and persist interactions after."""
def __init__(self, db: AsyncSession) -> None:
self._db = db
@@ -51,7 +51,7 @@ class MemoryMiddleware:
# ── Public API ────────────────────────────────────────────────────────────
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
"""Build memory context dict to inject into the orchestrator before LLM call.
"""Build memory context dict to inject into the agent before LLM call.
Returns a dict with keys:
core_memory — {key: plaintext_value, ...}

View File

@@ -1,210 +0,0 @@
"""Orchestrator — LLM-based intent router and agent pipeline."""
from __future__ import annotations
import json
from typing import Any, AsyncGenerator
from langchain_core.messages import HumanMessage, SystemMessage
from app.core.agent_registry import AgentRegistry, ChatAgent
from app.core.llm import get_router_llm
from app.core.agent_registry import registry as _default_registry
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
_FALLBACK_AGENT = "task_agent"
_CLASSIFY_SYSTEM = (
"You are an intent classifier. Given the user message and context, decide "
"which agent to route to.\n"
"Available agents: {agents}\n"
"Respond with just the agent name, nothing else."
)
_SYNTHESIZE_HUMAN = (
"Combine the following agent results into one coherent response.\n\n"
"Agent results:\n{results}\n\n"
"Original message: {message}"
)
def _make_llm():
return get_router_llm()
async def classify_intent(
message: str,
context: dict[str, Any],
reg: AgentRegistry,
) -> str:
"""Use gpt-4o-mini to classify intent and return the matching agent name.
Falls back to ``task_agent`` when the registry is empty or the model
returns a name that is not registered.
"""
agents = reg.list_agents()
if not agents:
return _FALLBACK_AGENT
system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
# Truncate context to keep the classification prompt short
human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
llm = _make_llm()
response = await llm.ainvoke(
[SystemMessage(content=system), HumanMessage(content=human)]
)
agent_name = str(response.content).strip().lower()
known = {a["name"] for a in agents}
return agent_name if agent_name in known else _FALLBACK_AGENT
async def route_single(
agent_name: str,
message: str,
context: dict[str, Any],
reg: AgentRegistry,
) -> ChatResponse:
"""Route to a single agent and wrap the result in a ``ChatResponse``."""
response_text = await reg.call_agent(agent_name, message, context)
return ChatResponse(response=response_text)
async def route_pipeline(
agent_names: list[str],
message: str,
context: dict[str, Any],
reg: AgentRegistry,
) -> ChatResponse:
"""Execute agents sequentially; each agent receives previous results in context.
A final LLM synthesis call merges all results into one coherent response.
"""
previous_results: list[str] = []
for agent_name in agent_names:
ctx = {**context, "previous_results": list(previous_results)}
result = await reg.call_agent(agent_name, message, ctx)
previous_results.append(result)
results_str = "\n\n".join(
f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
)
human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
llm = _make_llm()
synthesis = await llm.ainvoke([HumanMessage(content=human)])
return ChatResponse(response=str(synthesis.content))
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
"""Build an ``ExecutionPlan`` for the resolved agent.
Uses ``ExecutionPlanBuilder`` with the server-side template registry.
If a default template exists for the agent, an LLM step is emitted;
otherwise a plain ``handle`` action step is used.
"""
from app.core.execution_plan import ExecutionPlanBuilder, template_registry
template_id = f"tpl_{agent_name}_default"
builder = ExecutionPlanBuilder(agent_name)
if template_registry.has(template_id):
builder.add_llm_step(template_id, {"message": message})
else:
builder.add_step("handle", {"message": message})
return builder.build()
async def orchestrate(
request: ChatRequest,
reg: AgentRegistry | None = None,
) -> ChatResponse | ExecutionPlan:
"""Main orchestration entry point.
* Classifies the user's intent to select an agent.
* ``execution_mode == 'direct'``: routes to the agent and returns a
``ChatResponse``.
* ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
resolved agent and a template-ID-only step (prompt IP stays server-side).
"""
if reg is None:
reg = _default_registry
context = request.context.model_dump()
agent_name = await classify_intent(request.message, context, reg)
if request.execution_mode == "direct":
return await route_single(agent_name, request.message, context, reg)
# plan mode — return plan, do not execute
return _build_plan(agent_name, request.message)
async def orchestrate_v3(
user_id: str,
message: str,
context: dict[str, Any],
reg: AgentRegistry | None = None,
) -> tuple[str, ChatAgent]:
"""v3 orchestration — returns (agent_name, agent_instance); caller drives execution.
Classifies intent and instantiates the matching agent. The caller is responsible
for invoking handle(), handle_stream(), or _tool_loop_stream() as needed.
"""
if reg is None:
reg = _default_registry
agent_name = await classify_intent(message, context, reg)
return agent_name, reg.get(agent_name)
async def orchestrate_v3_stream(
user_id: str,
message: str,
context: dict[str, Any],
reg: AgentRegistry | None = None,
agent_holder: list | None = None,
) -> AsyncGenerator[tuple[str, str], None]:
"""v3 streaming orchestration — yields (agent_name, token) pairs.
The first yield always carries the agent_name with an empty token so that
callers (e.g. FloatingFormatter) can detect the routing domain before any text
tokens arrive.
If *agent_holder* is provided (a list), the agent instance is appended so
callers can access ``agent.tool_results`` after the stream completes.
"""
if reg is None:
reg = _default_registry
agent_name = await classify_intent(message, context, reg)
agent = reg.get(agent_name)
if agent_holder is not None:
agent_holder.append(agent)
yield agent_name, "" # domain signal — no token yet
async for token in agent.handle_stream(message, context):
yield agent_name, token
async def orchestrate_stream(
request: ChatRequest,
reg: AgentRegistry | None = None,
) -> AsyncGenerator[str, None]:
"""Streaming orchestration — yields plain text chunks only.
The WebSocket handler in ``app/api/routes/chat.py`` is responsible for
wrapping each chunk in a ``text_chunk`` frame and sending the final
``final`` frame once the generator is exhausted.
Agents do not yet support token-level streaming; the full response is
fetched first (which may involve multiple WS round-trips for tool calls),
then emitted in fixed-size chunks.
"""
if reg is None:
reg = _default_registry
context = request.context.model_dump()
agent_name = await classify_intent(request.message, context, reg)
response_text = await reg.call_agent(agent_name, request.message, context)
chunk_size = 50
for i in range(0, len(response_text), chunk_size):
yield response_text[i : i + chunk_size]

View File

@@ -1,19 +1,30 @@
"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
FloatingFormatter: produces floating_domain, stream_text, stream_end
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
* ``("token", str)`` — supervisor text token
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
HomeFormatter:
* Streams text tokens as-is → emits ``WsStreamText``
(text may contain inline ``<type>[id,...]</type>`` entity tags
for the frontend to parse and render as interactive components)
* Attaches mutations → injects into ``WsStreamEnd``
FloatingFormatter:
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
* Streams text tokens → emits ``WsStreamText``
* Attaches mutations → injects into ``WsStreamEnd``
"""
from __future__ import annotations
import json
import logging
from collections.abc import AsyncGenerator
from typing import Any
from app.schemas import (
WsFloatingDomain,
WsStreamBlock,
WsStreamEnd,
WsStreamStart,
WsStreamText,
@@ -21,10 +32,7 @@ from app.schemas import (
logger = logging.getLogger(__name__)
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}
# Map agent name → floating domain
# Map sub-agent tool name → floating domain / entity type
_AGENT_DOMAIN: dict[str, str] = {
"task_agent": "tasks",
"timeline_agent": "timelines",
@@ -32,184 +40,68 @@ _AGENT_DOMAIN: dict[str, str] = {
"project_agent": "projects",
}
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
class HomeFormatter:
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
The LLM is expected to output a newline-delimited sequence of JSON objects,
each with a ``type`` field:
- ``text`` → yields WsStreamText immediately (word-by-word)
- ``chart`` → buffers full JSON, validates, yields WsStreamBlock
- ``entity_ref`` → resolves from tool_results, yields WsStreamBlock
- ``table`` → buffers full JSON, validates, yields WsStreamBlock
- ``timeline`` → buffers full JSON, validates, yields WsStreamBlock
Invalid or unknown blocks are logged and skipped — stream never crashes.
"""
def __init__(self, request_id: str, tool_results: list[dict]) -> None:
self.request_id = request_id
self.tool_results = tool_results
async def format(
self,
token_stream: AsyncGenerator[tuple[str, str], None],
) -> AsyncGenerator[WsFrame, None]:
yield WsStreamStart(request_id=self.request_id)
buffer = ""
async for _agent_name, token in token_stream:
if not token:
continue
buffer += token
# Flush any complete JSON objects from the buffer
async for frame in self._flush_complete_objects(buffer):
buffer = "" # reset after flush
yield frame
break # only one flush per iteration; rest accumulates
# Flush any remaining content
if buffer.strip():
async for frame in self._flush_complete_objects(buffer, final=True):
yield frame
yield WsStreamEnd(request_id=self.request_id)
async def _flush_complete_objects(
self, text: str, final: bool = False
) -> AsyncGenerator[WsFrame, None]:
"""Try to parse and yield all complete JSON objects from *text*.
Yields nothing if text is incomplete JSON (unless *final* is True,
in which case remaining text is emitted as plain stream_text).
"""
remaining = text.strip()
while remaining:
# Fast path: plain text (not JSON)
if not remaining.startswith("{"):
# Yield as plain text chunk
newline_idx = remaining.find("\n")
if newline_idx == -1:
if final:
yield WsStreamText(request_id=self.request_id, chunk=remaining)
remaining = ""
else:
return # accumulate more
else:
line = remaining[:newline_idx].strip()
remaining = remaining[newline_idx + 1:].strip()
if line:
yield WsStreamText(request_id=self.request_id, chunk=line)
continue
# Try to decode a JSON object
try:
obj, end_idx = _try_parse_json(remaining)
except ValueError:
if final:
# Emit as raw text if we can't parse
yield WsStreamText(request_id=self.request_id, chunk=remaining)
remaining = ""
return
if obj is None:
if final:
yield WsStreamText(request_id=self.request_id, chunk=remaining)
remaining = ""
return # incomplete — need more tokens
remaining = remaining[end_idx:].strip()
block_type = obj.get("type")
frame = self._dispatch_block(obj, block_type)
if frame is not None:
yield frame
def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
if block_type == "text":
content = obj.get("content", "")
if content:
return WsStreamText(request_id=self.request_id, chunk=str(content))
return None
if block_type == "chart":
chart_type = obj.get("chartType")
if chart_type not in _VALID_CHART_TYPES:
logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
return None
if not isinstance(obj.get("data"), list):
logger.warning("HomeFormatter: chart missing data array — skipping")
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="chart",
data=obj,
)
if block_type == "entity_ref":
entity = obj.get("entity")
resolved = self._resolve_entity(entity)
if resolved is None:
logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="entity_ref",
data={"entity": entity, "items": resolved},
)
if block_type == "table":
if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
logger.warning("HomeFormatter: table missing headers/rows — skipping")
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="table",
data=obj,
)
if block_type == "timeline":
if not isinstance(obj.get("timelines"), list):
logger.warning("HomeFormatter: timeline missing timelines — skipping")
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="timeline",
data=obj,
)
logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
return None
def _resolve_entity(self, entity: str | None) -> list[dict] | None:
"""Find matching items in tool_results by entity type."""
if not entity:
return None
matches = [r for r in self.tool_results if r.get("entity") == entity]
return matches if matches else None
class FloatingFormatter:
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
Emits floating_domain immediately (from agent_name), then streams all tokens
as plain stream_text — no block parsing for floating context.
Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
is responsible for parsing those and rendering interactive components.
Mutations are attached to ``WsStreamEnd``.
"""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._mutations: list[dict] = []
async def format(
self,
token_stream: AsyncGenerator[tuple[str, str], None],
event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]:
yield WsStreamStart(request_id=self.request_id)
async for event_type, data in event_stream:
if event_type == "token":
if data:
yield WsStreamText(request_id=self.request_id, chunk=data)
elif event_type == "mutations":
self._mutations = data or []
yield WsStreamEnd(
request_id=self.request_id,
mutations=[
{"action": m["action"], "table": m["table"], "data": m["data"]}
for m in self._mutations
],
)
class FloatingFormatter:
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
``task_agent`` → ``"tasks"``), then streams text tokens as plain
``WsStreamText``. No block parsing for floating context.
"""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._mutations: list[dict] = []
async def format(
self,
event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]:
domain_sent = False
async for agent_name, token in token_stream:
if not domain_sent:
domain = _AGENT_DOMAIN.get(agent_name, "tasks")
async for event_type, data in event_stream:
if event_type == "tool_end" and not domain_sent:
# Sniff domain from the first sub-agent that completes
name = data.get("name", "")
domain = _AGENT_DOMAIN.get(name, "tasks")
yield WsFloatingDomain(
request_id=self.request_id,
domain=domain, # type: ignore[arg-type]
@@ -217,28 +109,33 @@ class FloatingFormatter:
yield WsStreamStart(request_id=self.request_id)
domain_sent = True
if token:
yield WsStreamText(request_id=self.request_id, chunk=token)
elif event_type == "token":
if not domain_sent:
# First token arrived before any tool_end — default domain
yield WsFloatingDomain(
request_id=self.request_id,
domain="tasks", # type: ignore[arg-type]
)
yield WsStreamStart(request_id=self.request_id)
domain_sent = True
if data:
yield WsStreamText(request_id=self.request_id, chunk=data)
yield WsStreamEnd(request_id=self.request_id)
elif event_type == "mutations":
self._mutations = data or []
# If no events triggered domain_sent (edge case), still emit structure
if not domain_sent:
yield WsFloatingDomain(
request_id=self.request_id,
domain="tasks", # type: ignore[arg-type]
)
yield WsStreamStart(request_id=self.request_id)
# ── helpers ───────────────────────────────────────────────────────────────────
def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
"""Attempt to parse the first complete JSON object from *text*.
Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the
object is incomplete, and raises ``ValueError`` when text is not JSON.
"""
decoder = json.JSONDecoder()
try:
obj, end_idx = decoder.raw_decode(text)
if not isinstance(obj, dict):
raise ValueError("Expected JSON object")
return obj, end_idx
except json.JSONDecodeError as exc:
# Incomplete JSON — need more tokens
if "Unterminated" in str(exc) or exc.pos == len(text):
return None, 0
raise ValueError(str(exc)) from exc
yield WsStreamEnd(
request_id=self.request_id,
mutations=[
{"action": m["action"], "table": m["table"], "data": m["data"]}
for m in self._mutations
],
)

View File

@@ -7,18 +7,21 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
from __future__ import annotations
import logging
from contextvars import ContextVar
from typing import Any, Callable, Coroutine
from uuid import uuid4
logger = logging.getLogger(__name__)
# Holds the execute callback for the current WS session.
# Set by the chat WS handler before the orchestrator runs; cleared after.
# Set by the chat WS handler before the deep agent runs; cleared after.
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
"_client_executor"
)
# Optional collector that captures raw execute_on_client results.
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
# Set by the deep agent tool loop to capture CRUD mutations.
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None
)
@@ -81,12 +84,17 @@ async def execute_on_client(
if limit is not None:
payload["limit"] = limit
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
result = await callback(payload)
if result is None:
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
else:
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
collector = _tool_result_collector.get(None)
if collector is not None:
if collector is not None and action in ("insert", "update", "delete"):
collector.append({
"action": action,
"table": table,
"data": result,
"data": data or {},
})
return result

View File

@@ -18,10 +18,7 @@ from app.config.settings import settings
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup: initialise DB connection pool and agent registry
from app.core.agent_registry import registry # noqa: F401 — triggers module load
import app.agents # noqa: F401 — triggers @registry.register decorators
# Startup: initialise DB connection pool
yield
# Shutdown: dispose SQLAlchemy connection pool
@@ -51,11 +48,10 @@ def create_app() -> FastAPI:
app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
app.include_router(plans.router, prefix="/api/v1")
app.include_router(storage.router, prefix="/api/v1")
app.include_router(vectors.router, prefix="/api/v1")
app.include_router(backup.router, prefix="/api/v1")

View File

@@ -41,41 +41,13 @@ class ChatContext(BaseModel):
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
class PlanAction(BaseModel):
type: Literal[
"create_record",
"update_record",
"delete_record",
"index_document",
"send_notification",
]
table: str | None = None
data: dict[str, Any] | None = None
class ChatRequest(BaseModel):
message: str
context: ChatContext = Field(default_factory=ChatContext)
execution_mode: Literal["direct", "plan"] = "direct"
class ChatResponse(BaseModel):
response: str
actions: list[PlanAction] = Field(default_factory=list)
# ── Execution Plans ──────────────────────────────────────────────────
class PlanStep(BaseModel):
action: str
prompt_template: str | None = None
variables: dict[str, Any] | None = None
data_from_step: int | None = None
class ExecutionPlan(BaseModel):
agent: str
steps: list[PlanStep] = Field(default_factory=list)
# ── Backup ───────────────────────────────────────────────────────────
@@ -179,7 +151,6 @@ class WsFrameType(str, Enum):
floating_request = "floating_request"
stream_start = "stream_start"
stream_text = "stream_text"
stream_block = "stream_block"
stream_end = "stream_end"
floating_domain = "floating_domain"
data_request = "data_request"
@@ -303,15 +274,6 @@ class WsStreamText(BaseModel):
chunk: str
class WsStreamBlock(BaseModel):
"""Server → Client: structured block (chart, table, entity, timeline)."""
type: Literal[WsFrameType.stream_block] = WsFrameType.stream_block
request_id: str
block_type: Literal["chart", "entity_ref", "table", "timeline"]
data: dict[str, Any]
class WsStreamEnd(BaseModel):
"""Server → Client: signals end of a streaming response."""

View File

@@ -4,6 +4,8 @@ gunicorn>=22.0.0
langchain>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.1.0
langgraph>=0.3.0
deepagents>=0.4.10
litellm>=1.50.0
pydantic>=2.10.0
pydantic-settings>=2.7.0

View File

@@ -1,214 +0,0 @@
"""Unit tests for the agent registry, base classes, and tool loop."""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.core.agent_registry import AgentRegistry, ChatAgent
# ── Helpers ──────────────────────────────────────────────────────────
class _StubAgent(ChatAgent):
"""Minimal concrete agent for testing."""
def get_name(self) -> str:
return "stub"
def get_description(self) -> str:
return "A stub agent for tests"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return f"echo: {query}"
class _AnotherAgent(ChatAgent):
def get_name(self) -> str:
return "another"
def get_description(self) -> str:
return "Another stub"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return "another"
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _fresh_registry():
"""Reset the singleton between tests."""
AgentRegistry._instance = None
yield
AgentRegistry._instance = None
@pytest.fixture()
def reg() -> AgentRegistry:
return AgentRegistry()
# ── Tests ────────────────────────────────────────────────────────────
class TestRegisterAndGet:
def test_register_decorator(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
agent = reg.get("stub")
assert isinstance(agent, _StubAgent)
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
with pytest.raises(KeyError, match="not found"):
reg.get("nonexistent")
def test_register_multiple(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
reg.register(_AnotherAgent)
assert reg.get("stub").get_name() == "stub"
assert reg.get("another").get_name() == "another"
class TestListAgents:
def test_empty(self, reg: AgentRegistry) -> None:
assert reg.list_agents() == []
def test_list_after_register(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
agents = reg.list_agents()
assert len(agents) == 1
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
def test_list_multiple(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
reg.register(_AnotherAgent)
names = {a["name"] for a in reg.list_agents()}
assert names == {"stub", "another"}
class TestCallAgent:
@pytest.mark.asyncio
async def test_call_agent(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
result = await reg.call_agent("stub", "hello", {})
assert result == "echo: hello"
@pytest.mark.asyncio
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
with pytest.raises(KeyError):
await reg.call_agent("nope", "hi", {})
class TestSingleton:
def test_singleton_identity(self) -> None:
a = AgentRegistry()
b = AgentRegistry()
assert a is b
class TestToolLoop:
@pytest.mark.asyncio
async def test_no_tool_calls(self) -> None:
"""When the LLM responds without tool calls, return content directly."""
agent = _StubAgent()
ai_msg = MagicMock()
ai_msg.content = "final answer"
ai_msg.tool_calls = []
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm)
llm.ainvoke = AsyncMock(return_value=ai_msg)
result = await agent._tool_loop(llm, [], [])
assert result == "final answer"
@pytest.mark.asyncio
async def test_tool_call_then_answer(self) -> None:
"""LLM requests one tool call, gets result, then answers."""
agent = _StubAgent()
# First response: tool call
tool_call_msg = MagicMock()
tool_call_msg.content = ""
tool_call_msg.tool_calls = [
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
]
# Second response: final answer
final_msg = MagicMock()
final_msg.content = "done"
final_msg.tool_calls = []
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm)
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
# Mock tool
tool = AsyncMock()
tool.name = "my_tool"
tool.ainvoke = AsyncMock(return_value="tool_result")
result = await agent._tool_loop(llm, [], [tool])
assert result == "done"
tool.ainvoke.assert_called_once_with({"x": 1})
@pytest.mark.asyncio
async def test_unknown_tool_handled(self) -> None:
"""Unknown tool names produce an error message instead of crashing."""
agent = _StubAgent()
tool_call_msg = MagicMock()
tool_call_msg.content = ""
tool_call_msg.tool_calls = [
{"id": "call_1", "name": "missing", "args": {}}
]
final_msg = MagicMock()
final_msg.content = "recovered"
final_msg.tool_calls = []
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm)
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
result = await agent._tool_loop(llm, [], [])
assert result == "recovered"
@pytest.mark.asyncio
async def test_max_iter_reached(self) -> None:
"""When max iterations are exhausted, a final no-tools call is made."""
agent = _StubAgent()
# Every response requests a tool call
loop_msg = MagicMock()
loop_msg.content = ""
loop_msg.tool_calls = [
{"id": "call_x", "name": "t", "args": {}}
]
final_msg = MagicMock()
final_msg.content = "gave up"
final_msg.tool_calls = []
tool = AsyncMock()
tool.name = "t"
tool.ainvoke = AsyncMock(return_value="ok")
llm_with_tools = AsyncMock()
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm.ainvoke = AsyncMock(return_value=final_msg)
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
assert result == "gave up"
assert llm_with_tools.ainvoke.call_count == 2

View File

@@ -1,416 +0,0 @@
"""Tests for ChatAgent streaming and tool result capture (Step 2)."""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from app.core.agent_registry import ChatAgent, registry
# ── Minimal concrete agent for testing ───────────────────────────────
class _EchoAgent(ChatAgent):
def get_name(self) -> str:
return "_echo"
def get_description(self) -> str:
return "Echo agent for tests"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return query
# ── Helpers ───────────────────────────────────────────────────────────
def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage:
msg = AIMessage(content=content)
if tool_calls:
msg.tool_calls = tool_calls
else:
msg.tool_calls = []
return msg
def _make_tool(name: str, return_value: Any) -> MagicMock:
t = MagicMock()
t.name = name
t.ainvoke = AsyncMock(return_value=return_value)
return t
def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]:
chunks = []
for tok in tokens:
c = MagicMock()
c.content = tok
chunks.append(c)
return chunks
async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]:
tokens: list[str] = []
async for tok in agent._tool_loop_stream(llm, messages, tools):
tokens.append(tok)
return tokens
# ── tool_results initialised ─────────────────────────────────────────
def test_tool_results_init():
agent = _EchoAgent()
assert agent.tool_results == []
# ── _tool_loop: no tool calls ────────────────────────────────────────
@pytest.mark.asyncio
async def test_tool_loop_no_tools():
agent = _EchoAgent()
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Hello!"))
result = await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
assert result == "Hello!"
assert agent.tool_results == []
# ── _tool_loop: with one tool call + result capture ──────────────────
@pytest.mark.asyncio
async def test_tool_loop_captures_tool_results():
agent = _EchoAgent()
# Mock execute_on_client to return structured data via the tool
raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]}
async def fake_executor(payload: dict) -> dict:
return raw_result
# AIMessage with a tool call, then a final answer
tool_call_msg = _make_ai_message(
tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}]
)
final_msg = _make_ai_message("Here are your tasks.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
llm.ainvoke = AsyncMock(return_value=final_msg)
mock_tool = _make_tool("list_tasks", "- Fix bug (todo)")
from app.core.ws_context import set_client_executor, clear_client_executor
set_client_executor(fake_executor)
try:
# Patch the tool to actually call execute_on_client
async def tool_side_effect(args: dict) -> str:
from app.core.ws_context import execute_on_client
res = await execute_on_client(action="select", table="tasks")
rows = res.get("rows", [])
return "\n".join(r["title"] for r in rows)
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
result = await agent._tool_loop(
llm, [HumanMessage(content="list my tasks")], [mock_tool]
)
finally:
clear_client_executor()
assert result == "Here are your tasks."
assert len(agent.tool_results) == 1
assert agent.tool_results[0] == raw_result
# ── _tool_loop: tool_results reset on each call ──────────────────────
@pytest.mark.asyncio
async def test_tool_loop_resets_tool_results():
agent = _EchoAgent()
agent.tool_results = [{"stale": True}] # pre-populated from a previous call
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done."))
await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
assert agent.tool_results == []
# ── _tool_loop: unknown tool name ────────────────────────────────────
@pytest.mark.asyncio
async def test_tool_loop_unknown_tool():
agent = _EchoAgent()
# No known tools — model still calls a non-existent one; loop handles gracefully
tool_call_msg = _make_ai_message(
tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}]
)
final_msg = _make_ai_message("Handled.")
mock_tool = _make_tool("known", "ok") # a different tool, not "nonexistent"
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool])
assert result == "Handled."
# ── _tool_loop: max_iter exhaustion ──────────────────────────────────
@pytest.mark.asyncio
async def test_tool_loop_max_iter():
agent = _EchoAgent()
always_tool = _make_ai_message(
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
)
fallback = _make_ai_message("Fallback.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
# Returns tool_call_msg on every iteration
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
llm.ainvoke = AsyncMock(return_value=fallback)
mock_tool = _make_tool("t", "ok")
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2)
assert result == "Fallback."
assert llm_with_tools.ainvoke.call_count == 2
# ── _tool_loop_stream: no tool calls — yields tokens ─────────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_no_tools_yields_tokens():
agent = _EchoAgent()
# No tools → llm used directly; ainvoke returns no tool calls → stream is used
no_tool_msg = _make_ai_message("irrelevant")
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
async def fake_astream(msgs):
for tok in ["Hello", " ", "world"]:
c = MagicMock()
c.content = tok
yield c
llm.astream = fake_astream
tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], [])
assert tokens == ["Hello", " ", "world"]
assert agent.tool_results == []
# ── _tool_loop_stream: one tool call then streaming final ─────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_with_tool_call():
agent = _EchoAgent()
raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}}
async def fake_executor(payload: dict) -> dict:
return raw_result
tool_call_msg = _make_ai_message(
tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}]
)
# After tools run, ainvoke returns no more tool calls
no_more_tools_msg = _make_ai_message("Task found.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
async def fake_astream(msgs):
for tok in ["Task", " ", "found."]:
c = MagicMock()
c.content = tok
yield c
llm.astream = fake_astream
async def tool_side_effect(args: dict) -> str:
from app.core.ws_context import execute_on_client
res = await execute_on_client(action="select", table="tasks", filters={"id": args.get("id")})
return res.get("row", {}).get("title", "")
mock_tool = _make_tool("get_task", "Deploy")
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
from app.core.ws_context import set_client_executor, clear_client_executor
set_client_executor(fake_executor)
try:
tokens = await _collect_stream(
agent, llm, [HumanMessage(content="get task t-2")], [mock_tool]
)
finally:
clear_client_executor()
assert tokens == ["Task", " ", "found."]
assert len(agent.tool_results) == 1
assert agent.tool_results[0] == raw_result
# ── _tool_loop_stream: tool_results reset on each call ───────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_resets_tool_results():
agent = _EchoAgent()
agent.tool_results = [{"old": True}]
no_tool_msg = _make_ai_message("")
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
async def fake_astream(msgs):
c = MagicMock()
c.content = "ok"
yield c
llm.astream = fake_astream
await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
assert agent.tool_results == []
# ── _tool_loop_stream: empty chunk content is skipped ────────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_skips_empty_chunks():
agent = _EchoAgent()
no_tool_msg = _make_ai_message("")
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
async def fake_astream(msgs):
for tok in ["", "hello", "", " world", ""]:
c = MagicMock()
c.content = tok
yield c
llm.astream = fake_astream
tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
assert tokens == ["hello", " world"]
# ── _tool_loop_stream: max_iter exhaustion falls back to stream ───────
@pytest.mark.asyncio
async def test_tool_loop_stream_max_iter():
agent = _EchoAgent()
always_tool = _make_ai_message(
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
)
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
async def fake_astream(msgs):
c = MagicMock()
c.content = "fallback"
yield c
llm.astream = fake_astream
mock_tool = _make_tool("t", "ok")
tokens = await _collect_stream(
agent, llm, [HumanMessage(content="x")], [mock_tool],
)
assert tokens == ["fallback"]
assert llm_with_tools.ainvoke.call_count == 5 # exhausted default max_iter
# ── _tool_loop_stream: multiple tool results captured ────────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_multiple_tool_results():
agent = _EchoAgent()
call_results = [
{"rows": [{"id": "t-1"}]},
{"rows": [{"id": "t-2"}]},
]
call_iter = iter(call_results)
async def fake_executor(payload: dict) -> dict:
return next(call_iter)
# Two tool calls in one iteration
tool_call_msg = _make_ai_message(
tool_calls=[
{"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"},
{"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"},
]
)
no_more_tools_msg = _make_ai_message("Done.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
async def fake_astream(msgs):
c = MagicMock()
c.content = "Done."
yield c
llm.astream = fake_astream
async def tool_side_effect(args: dict) -> str:
from app.core.ws_context import execute_on_client
res = await execute_on_client(action="select", table="tasks")
return str(res)
tool_a = _make_tool("tool_a", "")
tool_a.ainvoke = AsyncMock(side_effect=tool_side_effect)
tool_b = _make_tool("tool_b", "")
tool_b.ainvoke = AsyncMock(side_effect=tool_side_effect)
from app.core.ws_context import set_client_executor, clear_client_executor
set_client_executor(fake_executor)
try:
tokens = await _collect_stream(
agent, llm, [HumanMessage(content="x")], [tool_a, tool_b]
)
finally:
clear_client_executor()
assert tokens == ["Done."]
assert len(agent.tool_results) == 2
assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]}
assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]}

View File

@@ -1,761 +0,0 @@
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
from __future__ import annotations
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import app.agents # noqa: F401 — triggers @registry.register decorators
from app.agents.timeline_agent import TimelineAgent
from app.agents.note_agent import NoteAgent
from app.agents.project_agent import ProjectAgent
from app.agents.task_agent import TaskAgent
from app.core.agent_registry import registry
from app.core.ws_context import clear_client_executor, set_client_executor
# ── WS executor mock ──────────────────────────────────────────────────
#
# Tools call execute_on_client() which reads a ContextVar set by the WS
# handler. In unit tests there is no WS session, so we install a fake
# executor that returns plausible data for each action type.
_FAKE_ROW: dict[str, Any] = {
"id": "fake-id",
"title": "Fake Title",
"name": "Fake Name",
"status": "todo",
"priority": "medium",
"content": "Fake content",
"date": 1700000000000,
"taskId": "fake-task-id",
"author": "Alice",
"projectId": None,
}
async def _fake_executor(payload: dict) -> dict:
action = payload.get("action", "")
if action == "select":
return {"rows": []}
if action == "insert":
data = payload.get("data", {})
return {"row": {**_FAKE_ROW, **data}}
if action == "update":
data = payload.get("data", {})
row = {**_FAKE_ROW, "id": data.get("id", "fake-id"), **data.get("updates", {})}
return {"row": row}
if action == "delete":
return {"deleted": True}
if action == "get":
data = payload.get("data", {})
return {"row": {**_FAKE_ROW, "id": data.get("id", "fake-id")}}
if action == "vector_upsert":
return {"ok": True}
return {}
@pytest.fixture(autouse=True)
def ws_executor():
"""Install a fake WS executor for every test so tools can run without a real WS."""
set_client_executor(_fake_executor)
yield
clear_client_executor()
# ── Helpers ──────────────────────────────────────────────────────────
def _mock_llm(response_text: str) -> MagicMock:
"""Return a mock LLM that responds with *response_text* (no tool calls)."""
msg = MagicMock()
msg.content = response_text
msg.tool_calls = []
llm = MagicMock()
bound = MagicMock()
bound.ainvoke = AsyncMock(return_value=msg)
llm.bind_tools = MagicMock(return_value=bound)
llm.ainvoke = AsyncMock(return_value=msg)
return llm
def _mock_llm_with_tool_call(
tool_name: str, tool_args: dict[str, Any], final_text: str
) -> MagicMock:
"""Mock LLM that fires one tool call then returns *final_text*."""
tool_msg = MagicMock()
tool_msg.content = ""
tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
final_msg = MagicMock()
final_msg.content = final_text
final_msg.tool_calls = []
bound = MagicMock()
bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
llm = MagicMock()
llm.bind_tools = MagicMock(return_value=bound)
llm.ainvoke = AsyncMock(return_value=final_msg)
return llm
# ── Registration ──────────────────────────────────────────────────────
class TestAgentRegistration:
def test_all_agents_registered(self) -> None:
names = {a["name"] for a in registry.list_agents()}
assert {
"task_agent", "timeline_agent", "project_agent", "note_agent"
}.issubset(names)
def test_registry_returns_correct_types(self) -> None:
assert isinstance(registry.get("task_agent"), TaskAgent)
assert isinstance(registry.get("timeline_agent"), TimelineAgent)
assert isinstance(registry.get("project_agent"), ProjectAgent)
assert isinstance(registry.get("note_agent"), NoteAgent)
def test_descriptions_present(self) -> None:
for agent_info in registry.list_agents():
assert agent_info["description"], f"Empty description: {agent_info['name']}"
# ── TaskAgent ─────────────────────────────────────────────────────────
class TestTaskAgent:
def test_name(self) -> None:
assert TaskAgent().get_name() == "task_agent"
def test_description(self) -> None:
assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"
def test_get_tools_count(self) -> None:
assert len(TaskAgent().get_tools()) == 8
def test_tool_names(self) -> None:
names = {t.name for t in TaskAgent().get_tools()}
assert names == {
"list_tasks",
"create_task",
"update_task",
"delete_task",
"list_tasks_due_today",
"list_task_comments",
"add_task_comment",
"delete_task_comment",
}
@pytest.mark.asyncio
async def test_handle_returns_string(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Task created.")
result = await TaskAgent().handle("create a task", {})
assert isinstance(result, str)
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Here are your tasks.")
result = await TaskAgent().handle("list my tasks", {})
assert result == "Here are your tasks."
@pytest.mark.asyncio
async def test_handle_with_create_task_tool_call(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_task",
{"title": "Buy groceries", "priority": "low"},
"Task 'Buy groceries' created.",
)
result = await TaskAgent().handle("add a grocery task", {})
assert result == "Task 'Buy groceries' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await TaskAgent().handle("help", {})
assert isinstance(result, str)
@pytest.mark.asyncio
async def test_handle_accepts_rich_context(self) -> None:
context = {
"user_profile": {"id": "u1", "tier": "pro"},
"recent_tasks": [{"id": "t1", "title": "Old task"}],
}
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Tasks listed.")
result = await TaskAgent().handle("show tasks", context)
assert isinstance(result, str)
class TestTaskAgentTools:
@pytest.mark.asyncio
async def test_list_tasks_defaults(self) -> None:
from app.agents.task_agent import list_tasks
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_tasks.ainvoke({})
m.assert_called_once_with(
action="select", table="tasks",
filters={"projectId": None, "status": None, "search": None, "orderBy": None},
)
assert result == "No tasks found matching the given filters."
@pytest.mark.asyncio
async def test_list_tasks_with_status_filter(self) -> None:
from app.agents.task_agent import list_tasks
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
await list_tasks.ainvoke({"status": "done"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["filters"]["status"] == "done"
@pytest.mark.asyncio
async def test_create_task_defaults(self) -> None:
from app.agents.task_agent import create_task
fake_row = {"id": "t1", "title": "Test task", "status": "todo", "priority": "medium"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await create_task.ainvoke({"title": "Test task"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "insert"
assert call_kwargs["table"] == "tasks"
assert call_kwargs["data"]["title"] == "Test task"
assert call_kwargs["data"]["status"] == "todo"
assert call_kwargs["data"]["priority"] == "medium"
assert "Test task" in result
@pytest.mark.asyncio
async def test_create_task_with_all_fields(self) -> None:
from app.agents.task_agent import create_task
fake_row = {"id": "t1", "title": "Deploy", "status": "in_progress", "priority": "high"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await create_task.ainvoke({
"title": "Deploy", "priority": "high", "status": "in_progress",
"project_id": "p1", "is_ai_suggested": 1,
})
call_kwargs = m.call_args.kwargs
assert call_kwargs["data"]["priority"] == "high"
assert call_kwargs["data"]["status"] == "in_progress"
assert call_kwargs["data"]["projectId"] == "p1"
assert call_kwargs["data"]["isAiSuggested"] == 1
@pytest.mark.asyncio
async def test_update_task_with_status(self) -> None:
from app.agents.task_agent import update_task
fake_row = {"id": "t1", "title": "Buy groceries", "status": "done"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "update"
assert call_kwargs["data"]["id"] == "t1"
assert call_kwargs["data"]["updates"]["status"] == "done"
assert "t1" in result
@pytest.mark.asyncio
async def test_update_task_empty_updates(self) -> None:
from app.agents.task_agent import update_task
fake_row = {"id": "t1", "title": "Task", "status": "todo"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await update_task.ainvoke({"task_id": "t1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_task(self) -> None:
from app.agents.task_agent import delete_task
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_task.ainvoke({"task_id": "t1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "delete"
assert call_kwargs["table"] == "tasks"
assert call_kwargs["data"]["id"] == "t1"
assert "t1" in result
@pytest.mark.asyncio
async def test_list_tasks_due_today(self) -> None:
from app.agents.task_agent import list_tasks_due_today
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_tasks_due_today.ainvoke({})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "select"
assert call_kwargs["table"] == "tasks"
assert "dueDateFrom" in call_kwargs["filters"]
assert result == "No tasks are due today."
@pytest.mark.asyncio
async def test_list_task_comments(self) -> None:
from app.agents.task_agent import list_task_comments
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_task_comments.ainvoke({"task_id": "t1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "select"
assert call_kwargs["table"] == "taskComments"
assert call_kwargs["filters"]["taskId"] == "t1"
assert "t1" in result
@pytest.mark.asyncio
async def test_add_task_comment(self) -> None:
from app.agents.task_agent import add_task_comment
fake_row = {"id": "c1", "taskId": "t1", "author": "Alice", "content": "Looks good!"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await add_task_comment.ainvoke({
"task_id": "t1", "author": "Alice", "content": "Looks good!",
})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "insert"
assert call_kwargs["table"] == "taskComments"
assert call_kwargs["data"]["taskId"] == "t1"
assert call_kwargs["data"]["author"] == "Alice"
assert call_kwargs["data"]["content"] == "Looks good!"
assert "Alice" in result
@pytest.mark.asyncio
async def test_delete_task_comment(self) -> None:
from app.agents.task_agent import delete_task_comment
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_task_comment.ainvoke({"comment_id": "c1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "delete"
assert call_kwargs["table"] == "taskComments"
assert call_kwargs["data"]["id"] == "c1"
assert "c1" in result
# ── TimelineAgent ───────────────────────────────────────────────────
class TestTimelineAgent:
def test_name(self) -> None:
assert TimelineAgent().get_name() == "timeline_agent"
def test_description(self) -> None:
assert TimelineAgent().get_description() == "Manages project timelines (milestones): list, create, update, delete"
def test_get_tools_count(self) -> None:
assert len(TimelineAgent().get_tools()) == 4
def test_tool_names(self) -> None:
names = {t.name for t in TimelineAgent().get_tools()}
assert names == {"list_timelines", "create_timeline", "update_timeline", "delete_timeline"}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.timeline_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("No timelines found.")
result = await TimelineAgent().handle("list timelines", {})
assert result == "No timelines found."
@pytest.mark.asyncio
async def test_handle_with_create_tool_call(self) -> None:
with patch("app.agents.timeline_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_timeline",
{"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
"Timeline 'MVP Launch' created.",
)
result = await TimelineAgent().handle("add MVP timeline", {})
assert result == "Timeline 'MVP Launch' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.timeline_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await TimelineAgent().handle("show milestones", {})
assert isinstance(result, str)
class TestTimelineAgentTools:
@pytest.mark.asyncio
async def test_list_timelines_no_project(self) -> None:
from app.agents.timeline_agent import list_timelines
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_timelines.ainvoke({})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "select"
assert call_kwargs["table"] == "timelines"
assert call_kwargs["filters"]["projectId"] is None
assert result == "No timelines found."
@pytest.mark.asyncio
async def test_list_timelines_with_project(self) -> None:
from app.agents.timeline_agent import list_timelines
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
await list_timelines.ainvoke({"project_id": "p1"})
assert m.call_args.kwargs["filters"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_create_timeline(self) -> None:
from app.agents.timeline_agent import create_timeline
fake_row = {"id": "cp1", "title": "Beta release", "date": 1700000000000}
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await create_timeline.ainvoke({
"project_id": "p1", "title": "Beta release", "date": 1700000000000,
})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "insert"
assert call_kwargs["table"] == "timelines"
assert call_kwargs["data"]["projectId"] == "p1"
assert call_kwargs["data"]["title"] == "Beta release"
assert call_kwargs["data"]["date"] == 1700000000000
assert "Beta release" in result
@pytest.mark.asyncio
async def test_create_timeline_ai_suggested(self) -> None:
from app.agents.timeline_agent import create_timeline
fake_row = {"id": "cp1", "title": "Review", "date": 1700000000000}
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await create_timeline.ainvoke({
"project_id": "p1", "title": "Review", "date": 1700000000000, "is_ai_suggested": 1,
})
call_kwargs = m.call_args.kwargs
assert call_kwargs["data"]["isAiSuggested"] == 1
assert call_kwargs["data"]["isApproved"] == 0
@pytest.mark.asyncio
async def test_update_timeline_approve(self) -> None:
from app.agents.timeline_agent import update_timeline
fake_row = {"id": "c1", "title": "MVP", "isApproved": 1}
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await update_timeline.ainvoke({"timeline_id": "c1", "is_approved": 1})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "update"
assert call_kwargs["data"]["id"] == "c1"
assert call_kwargs["data"]["updates"]["isApproved"] == 1
assert "c1" in result
@pytest.mark.asyncio
async def test_update_timeline_empty_updates(self) -> None:
from app.agents.timeline_agent import update_timeline
fake_row = {"id": "c1", "title": "MVP"}
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await update_timeline.ainvoke({"timeline_id": "c1"})
assert m.call_args.kwargs["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_timeline(self) -> None:
from app.agents.timeline_agent import delete_timeline
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_timeline.ainvoke({"timeline_id": "c1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "delete"
assert call_kwargs["table"] == "timelines"
assert call_kwargs["data"]["id"] == "c1"
assert "c1" in result
# ── ProjectAgent ──────────────────────────────────────────────────────
class TestProjectAgent:
def test_name(self) -> None:
assert ProjectAgent().get_name() == "project_agent"
def test_description(self) -> None:
assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"
def test_get_tools_count(self) -> None:
assert len(ProjectAgent().get_tools()) == 6
def test_tool_names(self) -> None:
names = {t.name for t in ProjectAgent().get_tools()}
assert names == {
"list_projects",
"list_all_projects",
"get_project",
"create_project",
"update_project",
"delete_project",
}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.project_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Project Alpha is active.")
result = await ProjectAgent().handle("show my projects", {})
assert result == "Project Alpha is active."
@pytest.mark.asyncio
async def test_handle_with_create_project_tool_call(self) -> None:
with patch("app.agents.project_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_project",
{"name": "Pippo"},
"Project 'Pippo' created.",
)
result = await ProjectAgent().handle("create project Pippo", {})
assert result == "Project 'Pippo' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.project_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await ProjectAgent().handle("archive old project", {})
assert isinstance(result, str)
class TestProjectAgentTools:
@pytest.mark.asyncio
async def test_list_projects_defaults(self) -> None:
from app.agents.project_agent import list_projects
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_projects.ainvoke({})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "select"
assert call_kwargs["table"] == "projects"
assert call_kwargs["filters"]["includeArchived"] is False
assert result == "No projects found."
@pytest.mark.asyncio
async def test_list_projects_include_archived(self) -> None:
from app.agents.project_agent import list_projects
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
await list_projects.ainvoke({"include_archived": 1})
assert m.call_args.kwargs["filters"]["includeArchived"] is True
@pytest.mark.asyncio
async def test_list_all_projects(self) -> None:
from app.agents.project_agent import list_all_projects
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_all_projects.ainvoke({})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "select"
assert call_kwargs["table"] == "projects"
assert result == "No projects found."
@pytest.mark.asyncio
async def test_get_project(self) -> None:
from app.agents.project_agent import get_project
fake_row = {"id": "p1", "name": "Alpha", "status": "active", "clientId": None}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await get_project.ainvoke({"project_id": "p1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "get"
assert call_kwargs["table"] == "projects"
assert call_kwargs["data"]["id"] == "p1"
assert "Alpha" in result
@pytest.mark.asyncio
async def test_create_project_name_only(self) -> None:
from app.agents.project_agent import create_project
fake_row = {"id": "p1", "name": "Alpha"}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await create_project.ainvoke({"name": "Alpha"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "insert"
assert call_kwargs["data"]["name"] == "Alpha"
assert call_kwargs["data"]["clientId"] is None
assert "Alpha" in result
@pytest.mark.asyncio
async def test_create_project_with_client(self) -> None:
from app.agents.project_agent import create_project
fake_row = {"id": "p1", "name": "Beta"}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
assert m.call_args.kwargs["data"]["clientId"] == "cl1"
@pytest.mark.asyncio
async def test_update_project_archive(self) -> None:
from app.agents.project_agent import update_project
fake_row = {"id": "p1", "name": "Alpha", "status": "archived"}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "update"
assert call_kwargs["data"]["id"] == "p1"
assert call_kwargs["data"]["updates"]["status"] == "archived"
assert "p1" in result
@pytest.mark.asyncio
async def test_update_project_empty_updates(self) -> None:
from app.agents.project_agent import update_project
fake_row = {"id": "p1", "name": "Alpha", "status": "active"}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await update_project.ainvoke({"project_id": "p1"})
assert m.call_args.kwargs["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_project(self) -> None:
from app.agents.project_agent import delete_project
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_project.ainvoke({"project_id": "p1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "delete"
assert call_kwargs["data"]["id"] == "p1"
assert "p1" in result
# ── NoteAgent ─────────────────────────────────────────────────────────
class TestNoteAgent:
def test_name(self) -> None:
assert NoteAgent().get_name() == "note_agent"
def test_description(self) -> None:
assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"
def test_get_tools_count(self) -> None:
assert len(NoteAgent().get_tools()) == 5
def test_tool_names(self) -> None:
names = {t.name for t in NoteAgent().get_tools()}
assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.note_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Note created.")
result = await NoteAgent().handle("create a note", {})
assert result == "Note created."
@pytest.mark.asyncio
async def test_handle_with_create_note_tool_call(self) -> None:
with patch("app.agents.note_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_note",
{"title": "Daily log", "content": "# Today\nAll good."},
"Note 'Daily log' created.",
)
result = await NoteAgent().handle("log today's progress", {})
assert result == "Note 'Daily log' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.note_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await NoteAgent().handle("show notes", {})
assert isinstance(result, str)
class TestNoteAgentTools:
@pytest.mark.asyncio
async def test_list_notes_no_project(self) -> None:
from app.agents.note_agent import list_notes
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_notes.ainvoke({})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "select"
assert call_kwargs["table"] == "notes"
assert call_kwargs["filters"]["projectId"] is None
assert result == "No notes found."
@pytest.mark.asyncio
async def test_list_notes_with_project(self) -> None:
from app.agents.note_agent import list_notes
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
await list_notes.ainvoke({"project_id": "p1"})
assert m.call_args.kwargs["filters"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_get_note(self) -> None:
from app.agents.note_agent import get_note
fake_row = {"id": "n1", "title": "Daily log", "content": "# Today\nAll good."}
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await get_note.ainvoke({"note_id": "n1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "get"
assert call_kwargs["table"] == "notes"
assert call_kwargs["data"]["id"] == "n1"
assert "Daily log" in result
@pytest.mark.asyncio
async def test_create_note_minimal(self) -> None:
from app.agents.note_agent import create_note
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
m.return_value = {"row": fake_row}
me.return_value = [0.0] * 1536
result = await create_note.ainvoke({"title": "Daily log", "content": "# Today\nAll good."})
# First call: insert; second call: vector_upsert
first_call = m.call_args_list[0].kwargs
assert first_call["action"] == "insert"
assert first_call["table"] == "notes"
assert first_call["data"]["title"] == "Daily log"
assert first_call["data"]["content"] == "# Today\nAll good."
assert first_call["data"]["projectId"] is None
assert "Daily log" in result
@pytest.mark.asyncio
async def test_create_note_with_project(self) -> None:
from app.agents.note_agent import create_note
fake_row = {"id": "n1", "title": "Sprint notes", "projectId": "p1"}
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
m.return_value = {"row": fake_row}
me.return_value = [0.0] * 1536
await create_note.ainvoke({"title": "Sprint notes", "content": "## Sprint 1", "project_id": "p1"})
first_call = m.call_args_list[0].kwargs
assert first_call["data"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_update_note_content_only(self) -> None:
from app.agents.note_agent import update_note
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
m.return_value = {"row": fake_row}
me.return_value = [0.0] * 1536
result = await update_note.ainvoke({"note_id": "n1", "content": "# Updated content"})
first_call = m.call_args_list[0].kwargs
assert first_call["action"] == "update"
assert first_call["data"]["id"] == "n1"
assert first_call["data"]["updates"]["content"] == "# Updated content"
assert "title" not in first_call["data"]["updates"]
assert "n1" in result
@pytest.mark.asyncio
async def test_update_note_empty_updates(self) -> None:
from app.agents.note_agent import update_note
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await update_note.ainvoke({"note_id": "n1"})
assert m.call_args.kwargs["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_note(self) -> None:
from app.agents.note_agent import delete_note
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_note.ainvoke({"note_id": "n1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "delete"
assert call_kwargs["table"] == "notes"
assert call_kwargs["data"]["id"] == "n1"
assert "n1" in result

View File

@@ -1,286 +0,0 @@
"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache."""
from __future__ import annotations
import pytest
from app.core.execution_plan import (
ExecutionPlanBuilder,
PlanCache,
PromptTemplateRegistry,
plan_cache,
template_registry,
)
from app.schemas import ExecutionPlan
# ── PromptTemplateRegistry ────────────────────────────────────────────
class TestPromptTemplateRegistry:
def test_register_and_get(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_foo", "You are a foo agent.")
assert reg.get("tpl_foo") == "You are a foo agent."
def test_get_unknown_raises_key_error(self) -> None:
reg = PromptTemplateRegistry()
with pytest.raises(KeyError, match="tpl_missing"):
reg.get("tpl_missing")
def test_has_returns_true_for_registered(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_x", "prompt text")
assert reg.has("tpl_x") is True
def test_has_returns_false_for_unregistered(self) -> None:
reg = PromptTemplateRegistry()
assert reg.has("tpl_missing") is False
def test_list_ids_returns_all_registered_ids(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_a", "a")
reg.register("tpl_b", "b")
assert set(reg.list_ids()) == {"tpl_a", "tpl_b"}
def test_list_ids_does_not_return_prompt_text(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_secret", "top secret prompt")
ids = reg.list_ids()
assert "top secret prompt" not in ids
def test_overwrite_existing_template(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_x", "v1")
reg.register("tpl_x", "v2")
assert reg.get("tpl_x") == "v2"
def test_empty_registry_has_no_ids(self) -> None:
reg = PromptTemplateRegistry()
assert reg.list_ids() == []
# ── ExecutionPlanBuilder ──────────────────────────────────────────────
class TestExecutionPlanBuilder:
def test_builds_empty_plan(self) -> None:
plan = ExecutionPlanBuilder("task_agent").build()
assert plan.agent == "task_agent"
assert plan.steps == []
def test_add_step_basic(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("create_task", {"priority": "high"})
.build()
)
assert len(plan.steps) == 1
assert plan.steps[0].action == "create_task"
assert plan.steps[0].variables == {"priority": "high"}
assert plan.steps[0].prompt_template is None
assert plan.steps[0].data_from_step is None
def test_add_step_no_params(self) -> None:
plan = ExecutionPlanBuilder("task_agent").add_step("fetch").build()
assert plan.steps[0].variables is None
def test_add_llm_step(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_llm_step("tpl_task_default", {"message": "hi"})
.build()
)
assert plan.steps[0].action == "llm"
assert plan.steps[0].prompt_template == "tpl_task_default"
assert plan.steps[0].variables == {"message": "hi"}
def test_add_llm_step_no_variables(self) -> None:
plan = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build()
assert plan.steps[0].variables is None
def test_add_data_step(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("fetch_data")
.add_data_step("transform", data_from_step=0)
.build()
)
assert plan.steps[1].action == "transform"
assert plan.steps[1].data_from_step == 0
def test_fluent_chaining_returns_builder(self) -> None:
builder = ExecutionPlanBuilder("analytics_agent")
result = builder.add_step("a")
assert result is builder
def test_fluent_chain_multiple_steps(self) -> None:
plan = (
ExecutionPlanBuilder("analytics_agent")
.add_llm_step("tpl_analytics_default")
.add_step("format_output")
.add_data_step("store", data_from_step=0)
.build()
)
assert len(plan.steps) == 3
def test_build_validates_data_from_step_out_of_range(self) -> None:
with pytest.raises(ValueError, match="data_from_step"):
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=5).build()
def test_build_validates_data_from_step_self_reference(self) -> None:
"""data_from_step=0 on the first step (index 0) is invalid."""
with pytest.raises(ValueError, match="data_from_step"):
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=0).build()
def test_build_validates_data_from_step_negative(self) -> None:
with pytest.raises(ValueError, match="data_from_step"):
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=-1).build()
def test_valid_data_from_step_at_index_two(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("step0")
.add_step("step1")
.add_data_step("step2", data_from_step=1)
.build()
)
assert plan.steps[2].data_from_step == 1
def test_data_from_step_zero_valid_at_index_one(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("step0")
.add_data_step("step1", data_from_step=0)
.build()
)
assert plan.steps[1].data_from_step == 0
def test_build_returns_new_plan_each_call(self) -> None:
builder = ExecutionPlanBuilder("task_agent").add_step("do_thing")
plan1 = builder.build()
plan2 = builder.build()
assert plan1 is not plan2
assert plan1.steps == plan2.steps
def test_plan_is_execution_plan_instance(self) -> None:
plan = ExecutionPlanBuilder("task_agent").build()
assert isinstance(plan, ExecutionPlan)
# ── PlanCache ─────────────────────────────────────────────────────────
class TestPlanCache:
def _plan(self, agent: str = "a") -> ExecutionPlan:
return ExecutionPlanBuilder(agent).build()
def test_cache_and_get(self) -> None:
cache = PlanCache()
plan = self._plan()
cache.cache_plan("key1", plan)
assert cache.get_plan("key1") is plan
def test_get_missing_returns_none(self) -> None:
cache = PlanCache()
assert cache.get_plan("nonexistent") is None
def test_get_all_playbooks_empty(self) -> None:
cache = PlanCache()
assert cache.get_all_playbooks() == []
def test_get_all_playbooks_returns_all_stored(self) -> None:
cache = PlanCache()
p1, p2 = self._plan("a"), self._plan("b")
cache.cache_plan("k1", p1)
cache.cache_plan("k2", p2)
playbooks = cache.get_all_playbooks()
assert len(playbooks) == 2
assert p1 in playbooks
assert p2 in playbooks
def test_lru_evicts_oldest_entry(self) -> None:
cache = PlanCache(maxsize=2)
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
cache.cache_plan("k1", p1)
cache.cache_plan("k2", p2)
cache.cache_plan("k3", p3) # k1 should be evicted
assert cache.get_plan("k1") is None
assert cache.get_plan("k2") is p2
assert cache.get_plan("k3") is p3
def test_lru_access_updates_recency(self) -> None:
cache = PlanCache(maxsize=2)
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
cache.cache_plan("k1", p1)
cache.cache_plan("k2", p2)
cache.get_plan("k1") # k1 is now most-recently used
cache.cache_plan("k3", p3) # k2 should be evicted (LRU)
assert cache.get_plan("k1") is p1
assert cache.get_plan("k2") is None
assert cache.get_plan("k3") is p3
def test_overwrite_existing_key(self) -> None:
cache = PlanCache()
p1, p2 = self._plan("a"), self._plan("b")
cache.cache_plan("same_key", p1)
cache.cache_plan("same_key", p2)
assert cache.get_plan("same_key") is p2
assert len(cache.get_all_playbooks()) == 1
def test_overwrite_does_not_consume_capacity(self) -> None:
cache = PlanCache(maxsize=2)
p1, p2 = self._plan("a"), self._plan("b")
cache.cache_plan("k1", p1)
cache.cache_plan("k1", p2) # overwrite, not a new slot
cache.cache_plan("k2", p1) # should fit without eviction
assert cache.get_plan("k1") is p2
assert cache.get_plan("k2") is p1
# ── Module-level singletons ───────────────────────────────────────────
class TestModuleSingletons:
def test_template_registry_has_all_agent_defaults(self) -> None:
for agent in ("task_agent", "timeline_agent", "project_agent", "note_agent"):
assert template_registry.has(f"tpl_{agent}_default"), (
f"Missing template: tpl_{agent}_default"
)
def test_template_registry_has_operation_templates(self) -> None:
assert template_registry.has("tpl_task_extract_from_project")
assert template_registry.has("tpl_note_weekly_summary")
def test_template_registry_get_returns_non_empty_string(self) -> None:
text = template_registry.get("tpl_task_agent_default")
assert isinstance(text, str)
assert len(text) > 0
def test_plan_cache_has_prebuilt_playbooks(self) -> None:
assert len(plan_cache.get_all_playbooks()) >= 2
def test_playbook_create_tasks_from_project(self) -> None:
plan = plan_cache.get_plan("create_tasks_from_project")
assert plan is not None
assert plan.agent == "project_agent"
assert len(plan.steps) == 2
assert plan.steps[0].prompt_template == "tpl_task_extract_from_project"
assert plan.steps[1].data_from_step == 0
def test_playbook_generate_weekly_note(self) -> None:
plan = plan_cache.get_plan("generate_weekly_note")
assert plan is not None
assert plan.agent == "note_agent"
assert len(plan.steps) == 2
assert plan.steps[0].prompt_template == "tpl_note_weekly_summary"
assert plan.steps[1].data_from_step == 0
def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
"""Plans must not embed prompt text — only template IDs."""
for plan in plan_cache.get_all_playbooks():
for step in plan.steps:
if step.prompt_template is not None:
assert step.prompt_template.startswith("tpl_"), (
f"prompt_template looks like raw text: {step.prompt_template!r}"
)

View File

@@ -250,15 +250,15 @@ def test_home_request_calls_memory_middleware(client):
token = make_jwt("power", user_id=USER_ID)
session_id = str(uuid.uuid4())
async def _mock_stream(user_id, message, context, reg=None):
async def _mock_stream(user_id, message, context, db_session_factory=None):
# Verify memory context was injected
assert context.get("core_memory") == {"tz": "UTC"}
yield "task_agent", ""
yield "task_agent", '{"type": "text", "content": "Done"}'
yield ("token", "Done")
yield ("mutations", [])
with (
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_stream),
patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_stream),
):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({

View File

@@ -20,7 +20,6 @@ from jose import jwt
from app.config.settings import settings
from app.db import get_session
from app.main import app
from app.schemas import ChatResponse
from tests.conftest import TEST_USER_IDS
# ---------------------------------------------------------------------------
@@ -50,7 +49,6 @@ _CHAT_BODY = {
"recent_tasks": [],
"conversation_history": [],
},
"execution_mode": "direct",
}
@@ -240,7 +238,7 @@ class TestRateLimitMiddleware:
class TestSanitizerMiddleware:
"""Mock ``orchestrate`` to inject controlled strings into chat responses."""
"""Mock ``run_home`` to inject controlled strings into chat responses."""
_CHAT_PATH = "/api/v1/chat"
@@ -248,11 +246,10 @@ class TestSanitizerMiddleware:
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
def _post_chat(self, client: TestClient, response_text: str) -> dict:
mock_response = ChatResponse(response=response_text, actions=[])
with patch(
"app.api.routes.chat.orchestrate",
"app.api.routes.chat.run_home",
new_callable=AsyncMock,
return_value=mock_response,
return_value=response_text,
):
resp = client.post(
self._CHAT_PATH,

View File

@@ -1,347 +0,0 @@
"""Integration tests for the orchestrator module."""
from __future__ import annotations
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.core.agent_registry import AgentRegistry, ChatAgent
from app.core.orchestrator import (
classify_intent,
orchestrate,
orchestrate_stream,
route_pipeline,
route_single,
)
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
# ── Stub agents ──────────────────────────────────────────────────────
class _TaskAgent(ChatAgent):
def get_name(self) -> str:
return "task_agent"
def get_description(self) -> str:
return "Manages tasks: create, update, list, suggest"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return f"task: {query}"
class _CalendarAgent(ChatAgent):
def get_name(self) -> str:
return "calendar_agent"
def get_description(self) -> str:
return "Calendar management: events, conflicts, scheduling"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return f"calendar: {query}"
# ── Helpers ──────────────────────────────────────────────────────────
def _mock_llm(response_text: str) -> MagicMock:
"""Return a mock LLM that always produces *response_text*."""
msg = MagicMock()
msg.content = response_text
llm = MagicMock()
llm.ainvoke = AsyncMock(return_value=msg)
return llm
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _fresh_registry():
"""Reset the AgentRegistry singleton between tests."""
AgentRegistry._instance = None
yield
AgentRegistry._instance = None
@pytest.fixture()
def reg() -> AgentRegistry:
r = AgentRegistry()
r.register(_TaskAgent)
r.register(_CalendarAgent)
return r
# ── classify_intent ───────────────────────────────────────────────────
class TestClassifyIntent:
@pytest.mark.asyncio
async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
result = await classify_intent("add a task", {}, reg)
assert result == "task_agent"
@pytest.mark.asyncio
async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("calendar_agent")
result = await classify_intent("schedule a meeting", {}, reg)
assert result == "calendar_agent"
@pytest.mark.asyncio
async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("nonexistent_agent")
result = await classify_intent("do something", {}, reg)
assert result == "task_agent"
@pytest.mark.asyncio
async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
empty_reg = AgentRegistry()
# No LLM should be instantiated — early return path
with patch("app.core.orchestrator._make_llm") as mock_cls:
result = await classify_intent("anything", {}, empty_reg)
mock_cls.assert_not_called()
assert result == "task_agent"
@pytest.mark.asyncio
async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm(" task_agent \n")
result = await classify_intent("create task", {}, reg)
assert result == "task_agent"
# ── route_single ─────────────────────────────────────────────────────
class TestRouteSingle:
@pytest.mark.asyncio
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
result = await route_single("task_agent", "create a task", {}, reg)
assert isinstance(result, ChatResponse)
@pytest.mark.asyncio
async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
result = await route_single("task_agent", "create a task", {}, reg)
assert result.response == "task: create a task"
@pytest.mark.asyncio
async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
with pytest.raises(KeyError):
await route_single("nonexistent", "hello", {}, reg)
@pytest.mark.asyncio
async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
result = await route_single("task_agent", "hi", {}, reg)
assert result.actions == []
# ── route_pipeline ────────────────────────────────────────────────────
class TestRoutePipeline:
@pytest.mark.asyncio
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("synthesized result")
result = await route_pipeline(
["task_agent", "calendar_agent"], "plan my week", {}, reg
)
assert isinstance(result, ChatResponse)
@pytest.mark.asyncio
async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("synthesized result")
result = await route_pipeline(
["task_agent", "calendar_agent"], "plan my week", {}, reg
)
assert result.response == "synthesized result"
@pytest.mark.asyncio
async def test_passes_previous_results_to_subsequent_agents(
self, reg: AgentRegistry
) -> None:
"""Each agent after the first should receive prior outputs in context."""
received_contexts: list[dict[str, Any]] = []
class _CapturingAgent(ChatAgent):
def get_name(self) -> str:
return "capture"
def get_description(self) -> str:
return "captures context for testing"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
received_contexts.append(dict(context))
return "captured"
reg.register(_CapturingAgent)
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("done")
await route_pipeline(["task_agent", "capture"], "hi", {}, reg)
# The second agent (capture) must have received previous results
assert len(received_contexts) == 1
assert "previous_results" in received_contexts[0]
assert received_contexts[0]["previous_results"] == ["task: hi"]
@pytest.mark.asyncio
async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("single result")
result = await route_pipeline(["task_agent"], "one agent", {}, reg)
assert result.response == "single result"
# ── orchestrate ───────────────────────────────────────────────────────
class TestOrchestrate:
@pytest.mark.asyncio
async def test_direct_mode_returns_chat_response(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
result = await orchestrate(request, reg)
assert isinstance(result, ChatResponse)
@pytest.mark.asyncio
async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
result = await orchestrate(request, reg)
assert isinstance(result, ChatResponse)
assert result.response == "task: add a task"
@pytest.mark.asyncio
async def test_plan_mode_returns_execution_plan(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="plan my tasks", execution_mode="plan")
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
@pytest.mark.asyncio
async def test_plan_mode_agent_matches_classified(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("calendar_agent")
request = ChatRequest(
message="schedule something", execution_mode="plan"
)
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
assert result.agent == "calendar_agent"
@pytest.mark.asyncio
async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="plan tasks", execution_mode="plan")
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
assert len(result.steps) >= 1
@pytest.mark.asyncio
async def test_plan_mode_template_id_contains_agent_name(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="plan tasks", execution_mode="plan")
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
assert result.steps[0].prompt_template is not None
assert "task_agent" in result.steps[0].prompt_template
@pytest.mark.asyncio
async def test_default_execution_mode_is_direct(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
# execution_mode defaults to "direct"
request = ChatRequest(message="help me")
result = await orchestrate(request, reg)
assert isinstance(result, ChatResponse)
# ── orchestrate_stream ────────────────────────────────────────────────
class TestOrchestrateStream:
@pytest.mark.asyncio
async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
assert len(chunks) >= 1
@pytest.mark.asyncio
async def test_all_chunks_are_plain_text(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
# orchestrate_stream yields plain text chunks only — no JSON final frame
for chunk in chunks:
assert isinstance(chunk, str)
@pytest.mark.asyncio
async def test_concatenated_chunks_equal_full_response(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="create a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
full_text = "".join(chunks)
assert full_text == "task: create a task"
@pytest.mark.asyncio
async def test_text_chunks_before_final_frame(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(
message="x" * 200, execution_mode="direct"
) # long enough to produce multiple chunks
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
# All but the last chunk should be plain text (not valid final JSON)
non_final = chunks[:-1]
for chunk in non_final:
try:
parsed = json.loads(chunk)
assert parsed.get("done") is not True
except json.JSONDecodeError:
pass # plain text chunk — expected

View File

@@ -1,236 +0,0 @@
"""Tests for v3 orchestrator functions (Step 3)."""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Any
from app.core.agent_registry import ChatAgent, AgentRegistry
from app.core.orchestrator import orchestrate_v3, orchestrate_v3_stream
# ── Minimal agent for testing ─────────────────────────────────────────
class _FixedAgent(ChatAgent):
def __init__(self, name: str = "_fixed", tokens: list[str] | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._name = name
self._tokens = tokens or ["Hello", " world"]
def get_name(self) -> str:
return self._name
def get_description(self) -> str:
return "Fixed agent for tests"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return "".join(self._tokens)
async def handle_stream(self, query: str, context: dict[str, Any]):
for tok in self._tokens:
yield tok
# ── Mock registry factory ─────────────────────────────────────────────
def _make_registry(agent_name: str, agent: ChatAgent) -> MagicMock:
reg = MagicMock(spec=AgentRegistry)
reg.list_agents.return_value = [{"name": agent_name, "description": "test"}]
reg.get.return_value = agent
return reg
# ── orchestrate_v3 ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_orchestrate_v3_returns_agent_name_and_instance():
agent = _FixedAgent("task_agent")
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
name, inst = await orchestrate_v3(
user_id="u-1", message="fix a bug", context={}, reg=reg
)
assert name == "task_agent"
assert inst is agent
@pytest.mark.asyncio
async def test_orchestrate_v3_classify_called_with_message_and_context():
agent = _FixedAgent("note_agent")
reg = _make_registry("note_agent", agent)
ctx = {"some": "context"}
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")) as mock_classify:
await orchestrate_v3(user_id="u-1", message="take a note", context=ctx, reg=reg)
mock_classify.assert_awaited_once()
call_args = mock_classify.call_args
assert call_args[0][0] == "take a note"
assert call_args[0][1] == ctx
@pytest.mark.asyncio
async def test_orchestrate_v3_uses_default_registry_when_none():
agent = _FixedAgent("task_agent")
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
patch("app.core.orchestrator._default_registry") as mock_reg:
mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
mock_reg.get.return_value = agent
name, inst = await orchestrate_v3(user_id="u-1", message="hi", context={})
assert name == "task_agent"
assert inst is agent
@pytest.mark.asyncio
async def test_orchestrate_v3_get_called_with_agent_name():
agent = _FixedAgent("timeline_agent")
reg = _make_registry("timeline_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="timeline_agent")):
await orchestrate_v3(user_id="u-2", message="schedule", context={}, reg=reg)
reg.get.assert_called_once_with("timeline_agent")
# ── orchestrate_v3_stream ─────────────────────────────────────────────
async def _collect(gen) -> list[tuple[str, str]]:
results: list[tuple[str, str]] = []
async for item in gen:
results.append(item)
return results
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_first_yield_is_domain_signal():
agent = _FixedAgent("task_agent", tokens=["token1"])
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
results = await _collect(gen)
# First item must be (agent_name, "") — domain signal
assert results[0] == ("task_agent", "")
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_yields_agent_name_with_tokens():
agent = _FixedAgent("task_agent", tokens=["Hello", " ", "world"])
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
results = await _collect(gen)
# All items are (agent_name, token) pairs
assert all(name == "task_agent" for name, _ in results)
tokens = [tok for _, tok in results]
assert tokens[0] == "" # domain signal
assert tokens[1:] == ["Hello", " ", "world"]
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_different_agent():
agent = _FixedAgent("note_agent", tokens=["note"])
reg = _make_registry("note_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")):
gen = orchestrate_v3_stream(user_id="u-2", message="take note", context={}, reg=reg)
results = await _collect(gen)
assert results[0] == ("note_agent", "")
assert ("note_agent", "note") in results
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_uses_default_registry_when_none():
agent = _FixedAgent("task_agent", tokens=["x"])
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
patch("app.core.orchestrator._default_registry") as mock_reg:
mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
mock_reg.get.return_value = agent
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={})
results = await _collect(gen)
assert results[0][0] == "task_agent"
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_empty_token_list():
"""Agent with no tokens still emits the domain signal."""
class _EmptyAgent(_FixedAgent):
async def handle_stream(self, query: str, context: dict[str, Any]):
return
yield # makes it a generator
agent = _EmptyAgent("task_agent", tokens=[])
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
results = await _collect(gen)
assert results == [("task_agent", "")] # only domain signal
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_full_text_correct():
"""Concatenating all non-domain tokens reconstructs the full response."""
tokens = ["The", " ", "task", " ", "is", " ", "done."]
agent = _FixedAgent("task_agent", tokens=tokens)
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
results = await _collect(gen)
text = "".join(tok for _, tok in results[1:]) # skip domain signal
assert text == "The task is done."
# ── handle_stream default implementation ─────────────────────────────
@pytest.mark.asyncio
async def test_handle_stream_default_yields_full_response():
"""Default handle_stream yields handle() result as a single chunk."""
class _SimpleAgent(ChatAgent):
def get_name(self) -> str:
return "_simple"
def get_description(self) -> str:
return ""
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return "simple response"
agent = _SimpleAgent()
tokens = [tok async for tok in agent.handle_stream("q", {})]
assert tokens == ["simple response"]
@pytest.mark.asyncio
async def test_handle_stream_override_used_by_stream():
"""_FixedAgent.handle_stream override yields individual tokens."""
agent = _FixedAgent("t", tokens=["a", "b", "c"])
tokens = [tok async for tok in agent.handle_stream("q", {})]
assert tokens == ["a", "b", "c"]

View File

@@ -7,7 +7,6 @@ import pytest
from app.core.output_formatter import HomeFormatter, FloatingFormatter
from app.schemas import (
WsFloatingDomain,
WsStreamBlock,
WsStreamEnd,
WsStreamStart,
WsStreamText,
@@ -16,15 +15,15 @@ from app.schemas import (
# ── helpers ───────────────────────────────────────────────────────────────────
async def _stream(*pairs: tuple[str, str]):
"""Async generator that yields (agent_name, token) pairs."""
for pair in pairs:
yield pair
async def _stream(*events: tuple[str, object]):
"""Async generator that yields (event_type, data) tuples."""
for event in events:
yield event
async def collect(formatter, token_stream):
async def collect(formatter, event_stream):
frames = []
async for frame in formatter.format(token_stream):
async for frame in formatter.format(event_stream):
frames.append(frame)
return frames
@@ -32,13 +31,14 @@ async def collect(formatter, token_stream):
# ── HomeFormatter ─────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_home_formatter_text_block():
async def test_home_formatter_plain_text():
req_id = "req-1"
tokens = [
("task_agent", '{"type": "text", "content": "Hello world"}'),
events = [
("token", "Hello world"),
("mutations", []),
]
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(*tokens))
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsStreamStart)
assert frames[0].request_id == req_id
@@ -48,104 +48,92 @@ async def test_home_formatter_text_block():
@pytest.mark.asyncio
async def test_home_formatter_chart_block():
async def test_home_formatter_entity_tags_passed_through():
"""Entity tags are streamed as-is — the frontend parses them."""
req_id = "req-2"
chart_json = (
'{"type": "chart", "chartType": "bar", '
'"title": "Tasks", "data": [{"x": 1}], '
'"config": {"x": {"label": "X", "color": "#fff"}}}'
)
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", chart_json)))
events = [
("token", "Here is your project:\n<project>[abc-123]</project>\nAll good."),
("mutations", []),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].block_type == "chart"
assert block_frames[0].data["chartType"] == "bar"
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
assert "<project>[abc-123]</project>" in text
assert "Here is your project:" in text
assert "All good." in text
@pytest.mark.asyncio
async def test_home_formatter_invalid_chart_skipped():
async def test_home_formatter_multiple_tags_passed_through():
req_id = "req-3"
bad_chart = '{"type": "chart", "chartType": "unknown", "data": []}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", bad_chart)))
events = [
("token", "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"),
("mutations", []),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 0 # invalid chart skipped
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
assert "<project>[p1]</project>" in text
assert "<task>[t1,t2]</task>" in text
@pytest.mark.asyncio
async def test_home_formatter_entity_ref_resolved():
async def test_home_formatter_tool_end_ignored():
"""tool_end events are silently ignored by HomeFormatter."""
req_id = "req-4"
tool_results = [{"entity": "task", "id": "t1", "title": "My Task"}]
entity_json = '{"type": "entity_ref", "entity": "task"}'
formatter = HomeFormatter(request_id=req_id, tool_results=tool_results)
frames = await collect(formatter, _stream(("task_agent", entity_json)))
events = [
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
("token", "No tags here."),
("mutations", []),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].data["entity"] == "task"
assert block_frames[0].data["items"][0]["id"] == "t1"
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
assert text == "No tags here."
@pytest.mark.asyncio
async def test_home_formatter_entity_ref_missing_skipped():
async def test_home_formatter_mutations_in_stream_end():
req_id = "req-5"
entity_json = '{"type": "entity_ref", "entity": "task"}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", entity_json)))
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
events = [
("token", "Done"),
("mutations", muts),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 0 # no tool results → skipped
@pytest.mark.asyncio
async def test_home_formatter_table_block():
req_id = "req-6"
table_json = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", table_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].block_type == "table"
@pytest.mark.asyncio
async def test_home_formatter_timeline_block():
req_id = "req-7"
timeline_json = '{"type": "timeline", "timelines": [{"id": "c1", "title": "M1", "date": 123}]}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", timeline_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].block_type == "timeline"
end_frame = frames[-1]
assert isinstance(end_frame, WsStreamEnd)
assert len(end_frame.mutations) == 1
assert end_frame.mutations[0]["action"] == "insert"
@pytest.mark.asyncio
async def test_home_formatter_frame_order():
"""stream_start is first, stream_end is last."""
req_id = "req-8"
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", '{"type": "text", "content": "Hi"}')))
req_id = "req-6"
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
assert isinstance(frames[0], WsStreamStart)
assert isinstance(frames[-1], WsStreamEnd)
# ── FloatingFormatter ────────────────────────────────────────────────────────────
# ── FloatingFormatter ─────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_floating_formatter_domain_emitted_first():
async def test_floating_formatter_domain_from_tool_end():
req_id = "pop-1"
formatter = FloatingFormatter(request_id=req_id)
tokens = [
("task_agent", ""), # domain signal
("task_agent", "Hello"),
("task_agent", " there"),
events = [
("tool_end", {"name": "task_agent", "result": "ok"}),
("token", "Hello"),
("mutations", []),
]
frames = await collect(formatter, _stream(*tokens))
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "tasks"
@@ -156,8 +144,12 @@ async def test_floating_formatter_domain_emitted_first():
async def test_floating_formatter_text_only():
req_id = "pop-2"
formatter = FloatingFormatter(request_id=req_id)
tokens = [("timeline_agent", ""), ("timeline_agent", "Summary")]
frames = await collect(formatter, _stream(*tokens))
events = [
("tool_end", {"name": "timeline_agent", "result": "done"}),
("token", "Summary"),
("mutations", []),
]
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "timelines"
@@ -167,29 +159,56 @@ async def test_floating_formatter_text_only():
@pytest.mark.asyncio
async def test_floating_formatter_no_block_frames():
"""FloatingFormatter must never emit WsStreamBlock."""
async def test_floating_formatter_no_entity_tags():
"""FloatingFormatter never emits entity tag blocks."""
req_id = "pop-3"
formatter = FloatingFormatter(request_id=req_id)
tokens = [
("note_agent", ""),
("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'),
events = [
("tool_end", {"name": "note_agent", "result": "data"}),
("token", "some text"),
("mutations", []),
]
frames = await collect(formatter, _stream(*tokens))
assert not any(isinstance(f, WsStreamBlock) for f in frames)
frames = await collect(formatter, _stream(*events))
# Only expected frame types
for f in frames:
assert isinstance(f, (WsFloatingDomain, WsStreamStart, WsStreamText, WsStreamEnd))
@pytest.mark.asyncio
async def test_floating_formatter_end_frame():
req_id = "pop-4"
formatter = FloatingFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done")))
events = [
("tool_end", {"name": "project_agent", "result": "ok"}),
("token", "Done"),
("mutations", []),
]
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_floating_formatter_unknown_agent_defaults_to_tasks():
async def test_floating_formatter_default_domain_on_early_token():
"""When the first event is a token (no tool_end yet), default to 'tasks'."""
req_id = "pop-5"
formatter = FloatingFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi")))
events = [("token", "hi"), ("mutations", [])]
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "tasks"
@pytest.mark.asyncio
async def test_floating_formatter_mutations_in_stream_end():
req_id = "pop-6"
muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
events = [
("token", "Updated"),
("mutations", muts),
]
formatter = FloatingFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
end_frame = frames[-1]
assert isinstance(end_frame, WsStreamEnd)
assert len(end_frame.mutations) == 1

View File

@@ -88,7 +88,7 @@ class TestPluginRegistry:
async def test_list_filter_by_query(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session, query="time")
result = await reg.list_plugins(db_session, query="time tracker")
assert result.total == 1
assert result.plugins[0].id == "plugin-time-tracker"

View File

@@ -9,7 +9,6 @@ from app.schemas import (
WsFloatingDomain,
WsFloatingRequest,
WsFloatingScope,
WsStreamBlock,
WsStreamEnd,
WsStreamStart,
WsStreamText,
@@ -25,7 +24,6 @@ def test_v3_frame_types_exist():
"floating_request",
"stream_start",
"stream_text",
"stream_block",
"stream_end",
"floating_domain",
"data_request",
@@ -174,66 +172,6 @@ def test_stream_text_deserializes():
assert frame.chunk == "test"
# ── WsStreamBlock ─────────────────────────────────────────────────────
def test_stream_block_chart():
data = {
"type": "chart",
"chartType": "bar",
"title": "Tasks",
"data": [{"name": "Done", "count": 5}],
"config": {"count": {"label": "Count", "color": "#4f46e5"}},
}
frame = WsStreamBlock(request_id="r1", block_type="chart", data=data)
assert frame.type == WsFrameType.stream_block
assert frame.block_type == "chart"
assert frame.data["chartType"] == "bar"
def test_stream_block_entity_ref():
frame = WsStreamBlock(
request_id="r1",
block_type="entity_ref",
data={"type": "task", "id": "t-1", "title": "Fix bug"},
)
assert frame.block_type == "entity_ref"
def test_stream_block_table():
frame = WsStreamBlock(
request_id="r1",
block_type="table",
data={"headers": ["A", "B"], "rows": [["1", "2"]]},
)
assert frame.block_type == "table"
def test_stream_block_timeline():
frame = WsStreamBlock(
request_id="r1",
block_type="timeline",
data={"timelines": [{"id": "c1", "title": "Launch", "date": 1700000000}]},
)
assert frame.block_type == "timeline"
def test_stream_block_invalid_type():
with pytest.raises(ValidationError):
WsStreamBlock(
request_id="r1",
block_type="unknown", # type: ignore[arg-type]
data={},
)
def test_stream_block_serializes():
frame = WsStreamBlock(request_id="r1", block_type="table", data={"headers": [], "rows": []})
d = frame.model_dump()
assert d["type"] == "stream_block"
assert d["block_type"] == "table"
# ── WsStreamEnd ───────────────────────────────────────────────────────

View File

@@ -45,14 +45,15 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
return frames
async def _mock_home_stream(user_id, message, context, reg=None):
yield "task_agent", ""
yield "task_agent", '{"type": "text", "content": "Hello"}'
async def _mock_home_stream(user_id, message, context, db_session_factory=None):
yield "token", "Here are your tasks:\n<task>[t1,t2]</task>"
yield "mutations", []
async def _mock_floating_stream(user_id, message, context, reg=None):
yield "task_agent", ""
yield "task_agent", "Here is a summary"
async def _mock_floating_stream(user_id, message, context, scope=None, db_session_factory=None):
yield "tool_end", {"name": "task_agent", "result": "ok"}
yield "token", "Here is a summary"
yield "mutations", []
# ── tests ─────────────────────────────────────────────────────────────────────
@@ -61,7 +62,7 @@ def test_home_request_produces_stream_frames(client):
"""home_request → stream_start, stream_text+, stream_end."""
token = make_jwt("power", user_id=USER_ID)
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_home_stream):
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-1", "agent_ids": []
@@ -84,7 +85,7 @@ def test_floating_request_produces_domain_frame(client):
"""floating_request → floating_domain first, then stream_text*, stream_end."""
token = make_jwt("power", user_id=USER_ID)
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_floating_stream):
with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
@@ -112,11 +113,11 @@ def test_home_request_request_id_propagated(client):
token = make_jwt("power", user_id=USER_ID)
req_id = "my-unique-req-id"
async def _stream(user_id, message, context, reg=None):
yield "note_agent", ""
yield "note_agent", '{"type": "text", "content": "ok"}'
async def _stream(user_id, message, context, db_session_factory=None):
yield "token", "ok"
yield "mutations", []
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_stream):
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-3", "agent_ids": []