Compare commits
28 Commits
24772f2b67
...
feature/de
| Author | SHA1 | Date |
|---|---|---|
| | 47bf1881e5 | |
| | 24a9c1b752 | |
| | 706bf88883 | |
| | 4ff0b27084 | |
| | 61d2a18234 | |
| | b3687719b6 | |
| | f80bdfa8f7 | |
| | 617a17db40 | |
| | 92716cb89a | |
| | cfc9d7a942 | |
| | 2de67213f8 | |
| | f6ed383b3a | |
| | 9332e29e53 | |
| | 618076193a | |
| | 34f01234c9 | |
| | 0bd46937d3 | |
| | e6b5bc2e7d | |
| | c90ed58078 | |
| | 76c8f2bdad | |
| | 393b3befd6 | |
| | 2c08275934 | |
| | 7cb384fa63 | |
| | 7efaeba283 | |
| | b61ded8458 | |
| | ac71d99f9a | |
| | 3b3b3baf25 | |
| | 45415bb9ee | |
| | a775a2da18 | |
1
.gitignore
vendored
@@ -31,3 +31,4 @@ Thumbs.db
 # Claude Code
 .claude/
+logs/
@@ -1,512 +0,0 @@
# AI Refactor Plan — Adiuva Backend

> **Objective:** Transform backend tools from functions that return JSON action descriptors into real bidirectional executors. Each tool sends structured CRUD operations to the Electron client via WebSocket, receives real data back, and returns meaningful results to the LLM. The LLM reasons about actual user data instead of serialized action payloads.
>
> **Electron app:** Lives at `../adiuva/`. See `../adiuva/AI_REFACTOR_PLAN.md`.
>
> **Protocol:** Execute steps sequentially. Each step is atomic and committable. Mark `[x]` when done.

---

## Architecture — Before vs After

### Before (current)
```
LLM calls list_tasks(status="todo")
  → tool returns: '{"action":"list","table":"tasks","filters":{"status":"todo"}}'
  → _tool_loop feeds that JSON string as ToolMessage to LLM
  → LLM sees a descriptor, NOT real data — cannot reason about tasks
  → Final response: generic "Here are your tasks" (no actual task data)
  → Action descriptors sent in final WS frame for Electron to execute post-response
```

### After (target)
```
LLM calls list_tasks(status="todo")
  → tool calls execute_on_client(action="select", table="tasks", filters={status:"todo"})
  → WS frame sent to Electron: {type:"tool_call", id:"abc", action:"select", table:"tasks", filters:{status:"todo"}}
  → Electron runs: db.select().from(tasks).where(eq(tasks.status, "todo")).all()
  → WS frame back: {type:"tool_result", id:"abc", rows:[{id:"1",title:"Buy milk",...}, ...]}
  → tool returns: "Found 3 tasks: 1. Buy milk (high, due tomorrow) 2. ..."
  → _tool_loop feeds that as ToolMessage to LLM
  → LLM sees REAL data — can reason, count, compare, summarize
```

---

## WS Protocol — Typed Frames

| Direction | `type` | Payload |
|---|---|---|
| Client → Server | `chat_request` | `{ message: str, context: ChatContext }` |
| Server → Client | `text_chunk` | `{ text: str }` |
| Server → Client | `tool_call` | `{ id: str, action: str, table?: str, data?: dict, filters?: dict, vector?: list[float], limit?: int }` |
| Client → Server | `tool_result` | `{ id: str, row?: dict, rows?: list[dict], results?: list[dict], deleted?: bool, ok?: bool, error?: str }` |
| Server → Client | `final` | `{ response: str }` |
| Server → Client | `ping` | `{}` |

**Actions:**

| `action` | What Electron does (Drizzle) | `tool_result` shape |
|---|---|---|
| `select` | `db.select().from(table).where(filters)` | `{ rows: [...] }` |
| `get` | `db.select().from(table).where(id=...).get()` | `{ row: {...} or null }` |
| `insert` | `db.insert(table).values({id: uuid(), ...data}).returning().get()` | `{ row: {...} }` |
| `update` | `db.update(table).set(updates).where(id=...).returning().get()` | `{ row: {...} }` |
| `delete` | `db.delete(table).where(id=...).run()` | `{ deleted: true }` |
| `vector_upsert` | LanceDB upsert with pre-computed vector | `{ ok: true }` |
| `vector_search` | LanceDB search by vector | `{ results: [{id, content, score}...] }` |

**Electron generates IDs + timestamps.** Backend tools never send `id` or `createdAt` in `insert` data — Electron adds `id: uuid()`, `createdAt: Date.now()`, `updatedAt: Date.now()`.

---

## SQLite Schema Reference (Electron's local database)

Tools must use **camelCase** field names (Drizzle maps them to snake_case internally):

| Table | Columns |
|---|---|
| `tasks` | id, projectId, title, description, status (todo\|in_progress\|done), priority (high\|medium\|low), assignee (JSON array string), dueDate (ms), isAiSuggested (0\|1), isApproved (0\|1), createdAt (ms) |
| `projects` | id, clientId, name, status (active\|archived), aiSummary, createdAt (ms) |
| `checkpoints` | id, projectId (required), title, date (ms), isAiSuggested (0\|1), isApproved (0\|1), createdAt (ms) |
| `notes` | id, projectId, title, content (markdown), createdAt (ms), updatedAt (ms) |
| `taskComments` | id, taskId, author, content, createdAt (ms) |
| `clients` | id, parentId, name, industry, createdAt (ms) |

---

## Phase B — Backend Changes

### Step B.1 — WS context + frame types
- [x] Create `app/core/ws_context.py` (~25 lines):
  - `_client_executor: ContextVar[Callable]` — holds the async callback for the current WS session
  - `async def execute_on_client(action, table=None, data=None, filters=None, vector=None, limit=None) -> dict`:
    - Reads callback from ContextVar
    - Builds `tool_call` payload: `{id: str(uuid4()), action, table, data, filters, vector, limit}` (omits None fields)
    - Calls `await callback(payload)` — which sends the WS frame and waits for `tool_result`
    - Returns the result dict
  - `def set_client_executor(fn)` / `def clear_client_executor()` — ContextVar management
- [x] Add to `app/schemas.py`:
  - `WsFrameType(str, Enum)`: `chat_request`, `text_chunk`, `tool_call`, `tool_result`, `final`, `ping`
  - `WsToolCall(BaseModel)`: `type`, `id`, `action`, `table?`, `data?`, `filters?`, `vector?`, `limit?`
  - `WsToolResult(BaseModel)`: `type`, `id`, `row?`, `rows?`, `results?`, `deleted?`, `ok?`, `error?`
  - `WsTextChunk(BaseModel)`: `type`, `text`
  - `WsFinal(BaseModel)`: `type`, `response`
- **Files:** `app/core/ws_context.py`, `app/schemas.py`
- **Outcome:** Any tool can `await execute_on_client(...)` to query/mutate the user's local DB.
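
A minimal sketch of what `app/core/ws_context.py` might look like, assuming the callback is an async function that takes the payload dict and resolves to the `tool_result` dict. The `ClientExecutor` alias, the `default=None` sentinel, and the `RuntimeError` raised outside a session are assumptions added for illustration; the plan only names the ContextVar and the three functions.

```python
import asyncio
from contextvars import ContextVar
from typing import Awaitable, Callable, Optional
from uuid import uuid4

# Assumed callback shape: payload dict in, tool_result dict out.
ClientExecutor = Callable[[dict], Awaitable[dict]]

# Per-session callback; None means no WS session is active (assumption).
_client_executor: ContextVar[Optional[ClientExecutor]] = ContextVar(
    "client_executor", default=None
)


def set_client_executor(fn: ClientExecutor) -> None:
    _client_executor.set(fn)


def clear_client_executor() -> None:
    _client_executor.set(None)


async def execute_on_client(action, table=None, data=None, filters=None,
                            vector=None, limit=None) -> dict:
    callback = _client_executor.get()
    if callback is None:
        raise RuntimeError("execute_on_client called outside a WS session")
    optional = {"table": table, "data": data, "filters": filters,
                "vector": vector, "limit": limit}
    payload = {"id": str(uuid4()), "action": action}
    # Omit None fields so the frame only carries what the action needs.
    payload.update({k: v for k, v in optional.items() if v is not None})
    return await callback(payload)


async def demo() -> dict:
    async def fake_callback(payload: dict) -> dict:  # stands in for the WS round-trip
        return {"id": payload["id"], "rows": []}

    set_client_executor(fake_callback)
    try:
        return await execute_on_client("select", table="tasks",
                                       filters={"status": "todo"})
    finally:
        clear_client_executor()


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Because a ContextVar is copied into each asyncio task, tools that run in a task spawned after `set_client_executor` should see the same callback without passing the websocket through every call.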

### Step B.2 — Rewrite all 23 tools to use `execute_on_client()`
- [x] Each tool: same `@tool` decorator, same parameters, same docstring. Replace `return json.dumps({...})` body with:
  1. Call `result = await execute_on_client(action=..., table=..., data/filters=...)`
  2. Return human-readable string with confirmation + key data from `result`

- [x] **`app/agents/task_agent.py` (8 tools):**
  - `list_tasks(project_id, status, search, order_by)`:
    ```python
    # Empty arguments become None so Electron can skip those filters.
    result = await execute_on_client(action="select", table="tasks", filters={
        "projectId": project_id or None,
        "status": status or None,
        "search": search or None,
        "orderBy": order_by or None,
    })
    rows = result.get("rows", [])
    if not rows:
        return "No tasks found matching the given filters."
    lines = [f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})" for r in rows]
    return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
    ```
  - `create_task(title, ...)`:
    ```python
    # No id or timestamps here: Electron generates them on insert.
    result = await execute_on_client(action="insert", table="tasks", data={
        "title": title, "description": description or None, "status": status,
        "priority": priority, "assignee": assignees, "dueDate": due_date or None,
        "projectId": project_id or None, "isAiSuggested": is_ai_suggested, "isApproved": is_approved,
    })
    row = result["row"]
    return f"Task created: '{row['title']}' (id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
    ```
  - `update_task(task_id, ...)`: build updates dict (same logic as now) → `execute_on_client(action="update", table="tasks", data={"id": task_id, "updates": updates})` → return "Task updated: {title}"
  - `delete_task(task_id)`: `execute_on_client(action="delete", table="tasks", data={"id": task_id})` → return "Task deleted"
  - `list_tasks_due_today()`: calculate today's start/end ms → `execute_on_client(action="select", table="tasks", filters={"dueDateFrom": start, "dueDateTo": end})` → format + return
  - `list_task_comments(task_id)`: `execute_on_client(action="select", table="taskComments", filters={"taskId": task_id})` → format + return
  - `add_task_comment(task_id, author, content)`: `execute_on_client(action="insert", table="taskComments", data={...})` → return confirmation
  - `delete_task_comment(comment_id)`: `execute_on_client(action="delete", table="taskComments", data={"id": comment_id})` → return confirmation

- [x] **`app/agents/project_agent.py` (6 tools):**
  - `list_projects(client_id, include_archived)`: `execute_on_client(action="select", table="projects", filters={clientId, includeArchived})` → format + return
  - `list_all_projects()`: `execute_on_client(action="select", table="projects")` → format + return
  - `get_project(project_id)`: `execute_on_client(action="get", table="projects", data={"id": project_id})` → return project details or "not found"
  - `create_project(name, client_id)`: `execute_on_client(action="insert", table="projects", data={name, clientId})` → return confirmation + id
  - `update_project(project_id, ...)`: build updates → `execute_on_client(action="update", ...)` → return confirmation
  - `delete_project(project_id)`: `execute_on_client(action="delete", ...)` → return confirmation

- [x] **`app/agents/checkpoint_agent.py` (4 tools):**
  - `list_checkpoints(project_id)`: `execute_on_client(action="select", table="checkpoints", filters={projectId})` → format + return
  - `create_checkpoint(project_id, title, date, ...)`: `execute_on_client(action="insert", table="checkpoints", data={...})` → return confirmation + id
  - `update_checkpoint(checkpoint_id, ...)`: build updates → `execute_on_client(action="update", ...)` → return confirmation
  - `delete_checkpoint(checkpoint_id)`: `execute_on_client(action="delete", ...)` → return confirmation

- [x] **`app/agents/note_agent.py` (5 tools):**
  - `list_notes(project_id)`: `execute_on_client(action="select", table="notes", filters={projectId})` → format + return
  - `get_note(note_id)`: `execute_on_client(action="get", table="notes", data={"id": note_id})` → return full content or "not found"
  - `create_note(title, content, project_id)`: `execute_on_client(action="insert", table="notes", data={...})` → then `execute_on_client(action="vector_upsert", data={id, projectId, content}, vector=await embed(content))` → return confirmation
  - `update_note(note_id, ...)`: build updates → `execute_on_client(action="update", ...)` → then vector_upsert for updated content → return confirmation
  - `delete_note(note_id)`: `execute_on_client(action="delete", ...)` → return confirmation

- **Files:** `app/agents/task_agent.py`, `app/agents/project_agent.py`, `app/agents/checkpoint_agent.py`, `app/agents/note_agent.py`
- **Outcome:** All 23 tools query real user data via WS. LLM sees actual rows, not action descriptors.
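
The "calculate today's start/end ms" step for `list_tasks_due_today()` could be sketched as below. The helper name `day_bounds_ms` is hypothetical, and the plan does not say whether `dueDateTo` is inclusive or which timezone defines "today"; this sketch assumes an inclusive end bound and uses the timezone of the datetime passed in.

```python
from datetime import datetime, time, timedelta, timezone


def day_bounds_ms(now: datetime) -> tuple[int, int]:
    """Return (start, end) of now's calendar day in epoch milliseconds.

    The end bound is inclusive: it is the last millisecond of the day.
    """
    start = datetime.combine(now.date(), time.min, tzinfo=now.tzinfo)
    next_day = start + timedelta(days=1)
    start_ms = int(start.timestamp() * 1000)
    end_ms = int(next_day.timestamp() * 1000) - 1
    return start_ms, end_ms


if __name__ == "__main__":
    start, end = day_bounds_ms(datetime.now(timezone.utc))
    # These two values would be passed as dueDateFrom / dueDateTo.
    print({"dueDateFrom": start, "dueDateTo": end})
```

Keeping the bounds in milliseconds matches the `dueDate (ms)` column in the schema reference, so Electron can compare them without conversion.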

### Step B.3 — Bidirectional WebSocket handler
- [x] Refactor `app/api/routes/chat.py` WS endpoint:
  - After auth + accept + receive `chat_request`:
    1. Create `execute_on_client` callback closure capturing the websocket:
       ```python
       pending_calls: dict[str, asyncio.Future] = {}

       async def on_client_result(frame: dict):
           """Called when a tool_result frame arrives from Electron."""
           fut = pending_calls.pop(frame["id"], None)
           if fut and not fut.done():
               fut.set_result(frame)

       async def execute_callback(payload: dict) -> dict:
           """Send tool_call to Electron, wait for tool_result."""
           call_id = payload["id"]
           fut = asyncio.get_running_loop().create_future()
           pending_calls[call_id] = fut
           try:
               await websocket.send_text(json.dumps({"type": "tool_call", **payload}))
               return await asyncio.wait_for(fut, timeout=30.0)
           finally:
               # Drop the entry on timeout or send failure so it cannot leak.
               pending_calls.pop(call_id, None)
       ```
    2. Set `client_executor` ContextVar with `execute_callback`
    3. Run orchestrator in a task — it calls agents, agents call tools, tools call `execute_on_client()` which goes through the callback
    4. In parallel, run a message receive loop that dispatches incoming frames:
       - `tool_result` → `on_client_result(frame)`
       - `ping` → ignore
    5. Orchestrator yields `text_chunk` frames → send to client
    6. Send `final` frame when done
    7. Clear ContextVar
  - Keep heartbeat ping every 30s
  - 30s timeout on `tool_result` — if Electron doesn't respond, future raises `TimeoutError`, tool returns error string to LLM
- **Files:** `app/api/routes/chat.py`
- **Outcome:** Full bidirectional WS. Tool calls and text streaming happen concurrently on the same connection.
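
The interplay between the orchestrator task and the receive loop can be sketched end to end with standard-library asyncio only. `FakeWebSocket`, `handle_session`, and `fake_orchestrator` are hypothetical stand-ins: the real endpoint would use the framework's WebSocket object and the real orchestrator, and text chunk streaming and the heartbeat are left out here.

```python
import asyncio
import json


class FakeWebSocket:
    """Hypothetical stand-in: answers every tool_call with a canned tool_result."""

    def __init__(self) -> None:
        self.incoming: asyncio.Queue = asyncio.Queue()
        self.sent: list = []

    async def send_text(self, text: str) -> None:
        frame = json.loads(text)
        self.sent.append(frame)
        if frame["type"] == "tool_call":
            reply = {"type": "tool_result", "id": frame["id"],
                     "rows": [{"title": "Buy milk"}]}
            await self.incoming.put(json.dumps(reply))

    async def receive_text(self) -> str:
        return await self.incoming.get()


async def handle_session(websocket, run_orchestrator) -> str:
    pending_calls: dict = {}

    async def execute_callback(payload: dict) -> dict:
        fut = asyncio.get_running_loop().create_future()
        pending_calls[payload["id"]] = fut
        try:
            await websocket.send_text(json.dumps({"type": "tool_call", **payload}))
            return await asyncio.wait_for(fut, timeout=30.0)
        finally:
            pending_calls.pop(payload["id"], None)

    async def receive_loop() -> None:
        while True:
            frame = json.loads(await websocket.receive_text())
            if frame.get("type") == "tool_result":
                fut = pending_calls.get(frame["id"])
                if fut and not fut.done():
                    fut.set_result(frame)
            # Other frame types (for example "ping") are ignored.

    receiver = asyncio.create_task(receive_loop())
    try:
        response = await run_orchestrator(execute_callback)
    finally:
        receiver.cancel()
        await asyncio.gather(receiver, return_exceptions=True)
    await websocket.send_text(json.dumps({"type": "final", "response": response}))
    return response


async def fake_orchestrator(execute) -> str:
    result = await execute({"id": "abc", "action": "select", "table": "tasks"})
    return f"Found {len(result['rows'])} task(s)"


async def main() -> None:
    ws = FakeWebSocket()
    print(await handle_session(ws, fake_orchestrator))


if __name__ == "__main__":
    asyncio.run(main())
```

The receive loop has to run as its own task: if the handler awaited the orchestrator alone, nothing would read the `tool_result` frame and every tool call would sit until the 30s timeout.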

### Step B.4 — `_tool_loop` — no changes needed
- [x] Verify `app/core/agent_registry.py` works unchanged:
  - `_tool_loop` calls `tool_fn.ainvoke(args)` → tool awaits `execute_on_client()` (WS round-trip) → returns string → `ToolMessage(content=string)` → LLM sees real data
  - The async WS round-trip happens inside each tool. `_tool_loop` just sees an awaited tool returning a string — same as before, different content.
- **No code changes.** Just verify + add a log line for tool execution times if desired.
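
If the optional timing log is wanted, one way to add it without touching the tools is a small wrapper around the awaited call. This is only a sketch: `timed_ainvoke` and the logger name are hypothetical, and the `ainvoke` argument stands for whatever awaitable-returning callable `_tool_loop` already uses.

```python
import asyncio
import logging
import time

logger = logging.getLogger("tool_timing")  # hypothetical logger name


async def timed_ainvoke(tool_name: str, ainvoke, args: dict):
    """Await ainvoke(args) and log the elapsed wall-clock time in milliseconds."""
    started = time.perf_counter()
    try:
        return await ainvoke(args)
    finally:
        # Runs on success and on failure, so timeouts are logged as well.
        elapsed_ms = (time.perf_counter() - started) * 1000
        logger.info("tool %s finished in %.1f ms", tool_name, elapsed_ms)


async def demo() -> str:
    async def fake_tool(args: dict) -> str:  # stands in for tool_fn.ainvoke
        await asyncio.sleep(0.01)
        return "Found 0 tasks"

    return await timed_ainvoke("list_tasks", fake_tool, {})


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    print(asyncio.run(demo()))
```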

### Step B.5 — Orchestrator cleanup
- [x] Update `app/core/orchestrator.py`:
  - `orchestrate_stream()`: remove `"actions": []` from final frame. Final becomes: `{"done": true, "response": "..."}`
  - No other changes — `classify_intent` → `call_agent` → chunk response → final frame
- **Files:** `app/core/orchestrator.py`
- **Outcome:** Clean final frame. No more action descriptors in the protocol.

### Step B.6 — Add `/vectors/embed` endpoint
- [x] Add to `app/api/routes/vectors.py`:
  - `POST /api/v1/storage/vectors/embed`:
    - Request: `{ text: str }`
    - Response: `{ vector: list[float] }` (1536-dim from `text-embedding-3-small`)
    - Auth required (JWT)
  - Used by:
    - Backend tools: `note_agent` calls this before `vector_upsert`
    - Electron: `vectordb.ts` calls this for note embedding on create/update
- **Files:** `app/api/routes/vectors.py`
- **Outcome:** Single embedding endpoint. Both backend tools and Electron can generate vectors.

---

## Verification

| What to test | How |
|---|---|
| **Read flow** | "List my tasks" → `list_tasks` → `tool_call{select, tasks}` → Electron returns rows → LLM describes real tasks |
| **Write flow** | "Create a task called Buy milk" → `create_task` → `tool_call{insert, tasks, data:{title:"Buy milk"}}` → Electron inserts + returns row → tool confirms with id |
| **Multi-tool** | "How many todo tasks do I have?" → `list_tasks(status=todo)` → LLM counts actual rows → "You have 3 todo tasks" |
| **Vector search** | "Find notes about deployment" → tool embeds → `tool_call{vector_search, vector:[...]}` → Electron searches LanceDB → returns matching notes |
| **Vector upsert** | "Create a note about..." → insert note → vector_upsert with embedding → both SQLite + LanceDB updated |
| **Tool timeout** | Disconnect Electron mid-conversation → 30s timeout → tool returns error → LLM handles gracefully |
| **Concurrent calls** | Agent calls 2 tools in sequence → each does WS round-trip → both succeed → LLM sees both results |
| **_tool_loop max iter** | Verify 5-iteration limit still works → after 5 tool calls, LLM forced to answer without tools |

---

## Execution Notes

- **Phase 1 is the critical path.** Auth + backend client + drizzle executor + orchestrator refactor must land first.
- **Steps 1.1–1.4 are additive** — existing app keeps working until Step 1.5 swaps the orchestrator.
- **Step 2.1 is the point of no return** — after removing LangChain, there's no local AI fallback.
- **Phase B (backend changes) must land before Phase 1.3–1.5** — Electron needs the bidirectional WS to talk to.
- **Phase 3 and Phase 4 are independent** — can be parallelized after Phase 2.

---

## Phase 3 — Agent System: Config, Orchestration & Cloud Connectors

> **Objective:** Backend manages all agent configuration, scheduling, orchestration, and cloud data fetching. Two agent types: **Local Directory Agent** (backend triggers Electron to read files, then AI analyzes) and **Cloud Connector Agent** (backend fetches Gmail/Teams data directly, AI analyzes, pushes results to Electron via WS tool_call). All extracted items use existing WS tool infrastructure to insert into Electron's local DB with `is_ai_suggested=True`.
>
> **Electron Phase 3 plan:** `../adiuva/AI_REFACTOR_PLAN.md` Phase 3 section.
>
> **Electron UI status (2025):** Steps 3.6, 3.7, 3.8 of the Electron plan are ✅ complete. Agents are configured inside the Settings page (`/settings?section=agents`) — not a standalone route. The `JourneyDialog` (Step 3.8) is embedded inline in the Settings → Agents section. `LocalAgentConfigPanel` and `CloudAgentConfigPanel` (Step 3.7) are also inline. This affects the journey API contract (see Step 3.5 below).

### Architecture

```
Local Agent:
  Scheduler/manual trigger ──► check device online ──► WS agent_run → Electron
    ──► Electron reads files ──► WS agent_data → Backend
    ──► Backend AI (prompt_template + file content) ──► WS tool_call(insert) → Electron
    ──► Electron persists with isAiSuggested=1

Cloud Agent:
  Scheduler/manual trigger ──► Backend fetches Gmail/Teams (OAuth) ──► Backend AI analyzes
    ──► check device online ──► WS tool_call(insert) → Electron ──► Electron persists
```

**New WS frame types:**

| Direction | `type` | Payload |
|---|---|---|
| Server → Client | `agent_run` | `{ run_id, agent_id, config: { paths, file_extensions, prompt_template, data_types } }` |
| Client → Server | `agent_data` | `{ run_id, files: [{ path, name, content, metadata }] }` |
| Client → Server | `agent_complete` | `{ run_id, files_read, errors }` |
| Client → Server | `device_hello` | `{ device_id, agent_ids }` |

### Step 3.1 — Agent config tables
- [x] Add to `app/models.py`:
  - **`LocalAgentConfig`**:
    - `id` UUID PK
    - `user_id` FK → users
    - `device_id` str — identifies which Electron install this config belongs to
    - `name` str
    - `directory_paths` JSON — list of absolute paths on the device
    - `data_types` JSON — which tables to extract to: `["tasks", "notes", "checkpoints", "projects"]`
    - `prompt_template` text — user-configured via Chatbot Journey
    - `file_extensions` JSON — e.g. `[".eml", ".txt", ".pdf", ".md"]`
    - `schedule_cron` str — e.g. `"0 */6 * * *"` (every 6h)
    - `enabled` bool (default True)
    - `last_run_at` datetime nullable
    - `created_at`, `updated_at` timestamps
  - **`CloudAgentConfig`**:
    - `id` UUID PK
    - `user_id` FK → users
    - `provider` str — enum: `gmail`, `teams`, `outlook`
    - `name` str
    - `data_types` JSON — same format as local
    - `prompt_template` text
    - `oauth_token_encrypted` text — Fernet-encrypted OAuth2 credentials
    - `schedule_cron` str
    - `enabled` bool (default True)
    - `last_run_at` datetime nullable
    - `filter_config` JSON — provider-specific: `{ labels: [], date_range: {from, to}, senders: [] }`
    - `created_at`, `updated_at` timestamps
  - **`AgentRunLog`**:
    - `id` UUID PK
    - `agent_id` str — references LocalAgentConfig.id or CloudAgentConfig.id
    - `agent_type` str — `local` or `cloud`
    - `user_id` FK → users
    - `status` str — `running`, `success`, `error`, `partial`
    - `items_processed` int (default 0)
    - `items_created` int (default 0)
    - `errors` JSON — list of error strings
    - `started_at` datetime
    - `completed_at` datetime nullable
- [x] Add Pydantic schemas to `app/schemas.py`:
  - `LocalAgentConfigCreate`, `LocalAgentConfigUpdate`, `LocalAgentConfigResponse`
  - `CloudAgentConfigCreate`, `CloudAgentConfigUpdate`, `CloudAgentConfigResponse`
  - `AgentRunLogResponse`
  - `AgentCatalogItem` — `{ type, name, description, config_schema }`
  - `WsAgentRun`, `WsAgentData`, `WsAgentComplete`, `WsDeviceHello`
- [x] Generate Alembic migration
- **Files:** `app/models.py`, `app/schemas.py`, `alembic/versions/`
- **Outcome:** Agent config and run tracking tables in PostgreSQL.

### Step 3.2 — Agent CRUD API routes
- [x] Create `app/api/routes/agents.py`:
  - `GET /api/v1/agents/catalog` — returns hardcoded agent type catalog:
    - `local_directory`: "Watches local directories, extracts data from files using AI"
    - `gmail`: "Scans Gmail inbox, extracts tasks/notes from emails"
    - `teams`: "Monitors Teams messages, extracts action items"
    - `outlook`: "Scans Outlook inbox, extracts tasks/notes"
  - `GET /api/v1/agents/local` — list user's local agent configs
  - `POST /api/v1/agents/local` — create local agent config
    - Body: `{ name, device_id, directory_paths, data_types, prompt_template, file_extensions, schedule_cron }`
    - Tier check: count enabled agents ≤ `batch_active` limit
  - `PUT /api/v1/agents/local/{id}` — update config (ownership check)
  - `DELETE /api/v1/agents/local/{id}` — delete config + associated run logs
  - `GET /api/v1/agents/cloud` — list user's cloud agent configs
  - `POST /api/v1/agents/cloud` — create cloud connector config
    - Body: `{ provider, name, data_types, prompt_template, oauth_token_encrypted, schedule_cron, filter_config }`
    - Tier check: same `batch_active` limit (local + cloud count together)
  - `PUT /api/v1/agents/cloud/{id}` — update config
  - `DELETE /api/v1/agents/cloud/{id}` — delete config + run logs
  - `GET /api/v1/agents/runs` — query params: `agent_id`, `page`, `limit` → paginated run logs
  - `POST /api/v1/agents/{id}/run` — manual trigger (dispatches to agent runner)
  - All routes require JWT auth; ownership enforced on all mutations
- [x] Register router in `app/main.py`
- **Files:** `app/api/routes/agents.py`, `app/main.py`
- **Outcome:** Full CRUD for agent configs with tier-gated creation limits.
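
The shared tier check could be reduced to a single predicate, sketched below. The function name is hypothetical, and the reading of "count enabled agents ≤ `batch_active` limit" as "the count after creating the new agent must not exceed the limit" is an assumption, since the plan does not spell out whether the new agent is included in the count.

```python
def can_enable_another_agent(enabled_local: int, enabled_cloud: int,
                             batch_active_limit: int) -> bool:
    """Return True if one more enabled agent fits under the shared limit.

    Local and cloud agents draw from the same batch_active budget.
    """
    return enabled_local + enabled_cloud + 1 <= batch_active_limit


if __name__ == "__main__":
    # One local and one cloud agent under a limit of three: a third still fits.
    print(can_enable_another_agent(1, 1, 3))
    # Three agents already enabled under a limit of three: creation is rejected.
    print(can_enable_another_agent(2, 1, 3))
```

Both `POST /api/v1/agents/local` and `POST /api/v1/agents/cloud` would call the same predicate so the two routes cannot drift apart.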

### Step 3.3 — Device WS endpoint
- [x] Create `app/api/routes/device_ws.py`:
  - `WebSocket /api/v1/ws/device?token=<jwt>` — persistent connection from Electron
  - On connect:
    - Authenticate JWT
    - Receive `device_hello` frame → extract `device_id`, `agent_ids`
    - Store connection in `DeviceConnectionManager` (in-memory dict: `user_id → { ws, device_id }`)
    - Check for overdue agent runs → trigger them immediately
  - Message loop:
    - `agent_data` → route to active agent run handler
    - `agent_complete` → finalize agent run
    - `tool_result` → route to pending tool call (same pattern as chat WS)
    - `pong` → heartbeat ack
  - On disconnect:
    - Remove from `DeviceConnectionManager`
    - Mark any in-progress agent runs as `error` with "device disconnected"
  - Heartbeat: send `ping` every 30s, disconnect if no `pong` within 10s
- [x] Create `app/core/device_manager.py`:
  - `DeviceConnectionManager` (singleton):
    - `register(user_id, device_id, ws)` — stores active connection
    - `unregister(user_id)` — removes connection
    - `get_ws(user_id) -> WebSocket | None` — returns active WS if device is online
    - `is_online(user_id, device_id=None) -> bool` — optionally checks specific device
    - `send_frame(user_id, frame: dict)` — sends JSON frame to device
- **Files:** `app/api/routes/device_ws.py`, `app/core/device_manager.py`, `app/main.py`
- **Outcome:** Backend maintains persistent WS connections to Electron devices for agent triggers.
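
A minimal in-memory sketch of `DeviceConnectionManager` following the method list above. The `_Connection` dataclass, the `ConnectionError` raised for offline users, and the module-level `device_manager` instance used as the singleton are assumptions; the websocket is only expected to offer an async `send_text`, as in the chat handler.

```python
import json
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class _Connection:
    ws: Any          # the framework's WebSocket object
    device_id: str


class DeviceConnectionManager:
    """Tracks one active device connection per user, in memory only."""

    def __init__(self) -> None:
        self._connections: dict = {}

    def register(self, user_id: str, device_id: str, ws: Any) -> None:
        self._connections[user_id] = _Connection(ws=ws, device_id=device_id)

    def unregister(self, user_id: str) -> None:
        self._connections.pop(user_id, None)

    def get_ws(self, user_id: str) -> Optional[Any]:
        conn = self._connections.get(user_id)
        return conn.ws if conn else None

    def is_online(self, user_id: str, device_id: Optional[str] = None) -> bool:
        conn = self._connections.get(user_id)
        if conn is None:
            return False
        # Without a device_id, any connected device counts as online.
        return device_id is None or conn.device_id == device_id

    async def send_frame(self, user_id: str, frame: dict) -> None:
        ws = self.get_ws(user_id)
        if ws is None:
            raise ConnectionError(f"user {user_id} has no connected device")
        await ws.send_text(json.dumps(frame))


# Module-level instance shared by routes and the agent runner (assumed singleton style).
device_manager = DeviceConnectionManager()
```

Because the dict is keyed by `user_id` alone, a second device connecting for the same user replaces the first; that follows the `user_id → { ws, device_id }` shape in the plan rather than a multi-device design.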

### Step 3.4 — Agent run orchestrator
- [x] Create `app/core/agent_runner.py`:
  - `async run_local_agent(user_id, config: LocalAgentConfig, device_mgr: DeviceConnectionManager)`:
    1. Check device is online with matching `device_id` → abort if offline
    2. Create `AgentRunLog` with `status=running`
    3. Send `WsAgentRun` frame to Electron with config (paths, extensions, prompt)
    4. Await `WsAgentData` frames — collect file contents
    5. Await `WsAgentComplete` frame — Electron signals done reading
    6. For each file: call LLM with `prompt_template` + file content → extract structured items
    7. For each extracted item: send `WsToolCall(insert, table, data)` to Electron → await `WsToolResult`
       - All inserts include `is_ai_suggested=True, is_approved=False`
    8. Update `AgentRunLog`: `status=success`, `items_processed`, `items_created`
  - `async run_cloud_agent(user_id, config: CloudAgentConfig, device_mgr: DeviceConnectionManager)`:
    1. Check device is online → abort if offline (results must push to Electron)
    2. Create `AgentRunLog` with `status=running`
    3. Decrypt OAuth credentials from `config.oauth_token_encrypted`
    4. Fetch data from cloud provider (Step 3.6):
       - Gmail: `google-api-python-client` + `filter_config` label/date filters
       - Teams: `msgraph-sdk` + channel/date filters
       - Outlook: `msgraph-sdk` + folder/date filters
    5. For each item: call LLM with `prompt_template` + email/message content → extract structured items
    6. For each extracted item: send `WsToolCall(insert)` to Electron → await `WsToolResult`
    7. Update `AgentRunLog`
  - `async trigger_pending_runs(user_id, device_id, device_mgr)`:
    - Called when Electron connects (after `device_hello`)
    - Queries all enabled agent configs where `last_run_at + schedule_interval < now()`
    - For local agents: only triggers if `config.device_id == device_id`
    - For cloud agents: triggers regardless of device (any connected device can receive results)
    - Executes runs sequentially (one at a time to avoid overwhelming the WS)
  - Error handling: on any failure, update `AgentRunLog` with `status=error` + error details
- [x] Wire `POST /agents/{id}/run` endpoint to dispatch background task via `asyncio.create_task()`
- [x] Replace `_trigger_pending_runs_stub` in `device_ws.py` with real `trigger_pending_runs` call
- [x] Add `croniter>=3.0.0` to `requirements.txt`
- [x] 23 unit + integration tests covering all code paths
- **Files:** `app/core/agent_runner.py`, `app/api/routes/agents.py`, `app/api/routes/device_ws.py`, `requirements.txt`, `tests/test_agent_runner.py`
- **Outcome:** Backend drives all agent execution — both local (via WS file request) and cloud (direct API calls — stub until Step 3.6).
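
The run-log lifecycle shared by both runners (create with `status=running`, finish as `success`, or record `error` on any failure) might be factored out roughly as follows. `RunLog` is a simplified in-memory stand-in for the `AgentRunLog` model, and `run_with_log` is a hypothetical helper; persistence, the `partial` status, and the device checks are left out.

```python
import asyncio
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional


@dataclass
class RunLog:
    """Simplified stand-in for AgentRunLog (no database, fewer columns)."""
    status: str = "running"
    items_processed: int = 0
    items_created: int = 0
    errors: list = field(default_factory=list)
    started_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    completed_at: Optional[datetime] = None


async def run_with_log(do_run) -> RunLog:
    """Run do_run(log) and record the outcome on the log.

    do_run is expected to update items_processed / items_created itself.
    """
    log = RunLog()
    try:
        await do_run(log)
        log.status = "success"
    except Exception as exc:
        # Any failure ends the run as "error" with the message kept for the UI.
        log.status = "error"
        log.errors.append(str(exc))
    finally:
        log.completed_at = datetime.now(timezone.utc)
    return log


async def demo() -> RunLog:
    async def fake_run(log: RunLog) -> None:  # stands in for fetch + extract + insert
        log.items_processed = 2
        log.items_created = 1

    return await run_with_log(fake_run)


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Catching `Exception` rather than `BaseException` leaves task cancellation alone, so a cancelled run is not silently recorded as an ordinary error.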
|
|
||||||
|
|
||||||
### Step 3.5 — Chatbot Journey endpoint
|
|
||||||
- [x] Create `app/api/routes/agent_setup.py`:
|
|
||||||
- `POST /api/v1/agents/journey/start`:
|
|
||||||
- Body: `{ agent_type: "local"|"cloud", agent_id: str | None }`
|
|
||||||
- `agent_type`: which kind of agent this journey configures.
|
|
||||||
- `agent_id`: optional — if provided, the session is pre-seeded with the existing agent's `prompt_template` so the user can refine it. If absent, fresh journey.
|
|
||||||
- **No `data_types` field** — data types are determined through the conversation itself, not sent upfront.
|
|
||||||
- Creates a journey session (in-memory or Redis-backed)
|
|
||||||
- Returns first AI message: contextual question based on agent type
|
|
||||||
- Local: "What kind of files are in the directories you want to monitor? (emails, documents, logs, etc.)"
|
|
||||||
- Cloud: "What kind of emails/messages should I look for? (client communications, invoices, meeting notes, etc.)"
|
|
||||||
- Response: `{ session_id, message, done: false }`
|
|
||||||
- **Electron note:** `proxyPost` auto-converts camelCase keys to snake_case. Electron sends `{ agentType, agentId }` → backend receives `{ agent_type, agent_id }`.
|
|
||||||
- `POST /api/v1/agents/journey/message`:
|
|
||||||
- Body: `{ session_id, message }`
|
|
||||||
- AI processes user's answer, asks follow-up questions (max 5 turns)
|
|
||||||
- System prompt: "You are configuring a data extraction agent for a freelancer. Ask about file format, what data to extract (tasks, notes, checkpoints), naming conventions, priority rules, and any special mapping. After 3-5 questions, generate a detailed prompt_template."
|
|
||||||
- When AI determines enough context: `{ session_id, message: "Here's your configuration...", done: true, prompt_template: "..." }`
|
|
||||||
- The `prompt_template` is a structured instruction for the extraction LLM (e.g. "Extract tasks from email. Subject becomes task title. If body contains 'urgent' or 'ASAP', set priority to 'high'. Extract due dates if mentioned.")
|
|
||||||
- **Electron note:** `toCamelCase` converts the response → Electron reads `promptTemplate` from the final message and auto-fills the agent config panel. User clicks "Save & apply" which calls `agent.local.update` / `agent.cloud.update` tRPC mutation.
|
|
||||||
- **Files:** `app/api/routes/agent_setup.py`, `app/main.py`
|
|
||||||
- **Outcome:** Users configure AI prompts through guided conversation. Journey can refine an existing config when `agent_id` is provided. ✅
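
The turn-capped journey described in this step could be sketched as follows. This is only an illustration under stated assumptions: `JourneySession`, `MAX_TURNS`, `handle_message`, and the `llm_turn` hook are hypothetical names, the session lives in memory, and the LLM call is replaced by a caller-supplied function.

```python
import uuid
from dataclasses import dataclass, field
from typing import Callable, Optional

MAX_TURNS = 5  # the plan caps the journey at 5 turns


@dataclass
class JourneySession:
    agent_type: str                      # "local" or "cloud"
    seed_template: Optional[str] = None  # existing prompt_template when agent_id is given
    session_id: str = field(default_factory=lambda: uuid.uuid4().hex)
    answers: list = field(default_factory=list)


# Hypothetical LLM hook: takes (session, must_finish) and returns
# (reply_text, prompt_template or None).
LlmTurn = Callable[[JourneySession, bool], tuple]


def handle_message(session: JourneySession, message: str, llm_turn: LlmTurn) -> dict:
    """Record one user answer and build the next response frame."""
    session.answers.append(message)
    # Once the cap is reached, the hook is asked to produce the template.
    must_finish = len(session.answers) >= MAX_TURNS
    reply, template = llm_turn(session, must_finish)
    frame = {
        "session_id": session.session_id,
        "message": reply,
        "done": template is not None,
    }
    if template is not None:
        frame["prompt_template"] = template
    return frame
```

The frame keys mirror the response shapes listed above; everything else is an assumption made for the sketch.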
|
|
||||||
|
|
||||||
### Step 3.6 — Cloud provider integrations
|
|
||||||
- [ ] Create `app/integrations/gmail.py`:
|
|
||||||
- `GmailClient`:
|
|
||||||
- `__init__(oauth_token)` — initializes Google API client
|
|
||||||
- `async fetch_messages(filter_config, since: datetime) -> list[EmailMessage]`
|
|
||||||
- `EmailMessage`: `{ id, subject, sender, body_text, date, labels }`
|
|
||||||
- Handles token refresh via Google OAuth2 refresh flow
|
|
||||||
- Respects `filter_config.labels`, `filter_config.date_range`, `filter_config.senders`
|
|
||||||
- [ ] Create `app/integrations/ms_graph.py`:
|
|
||||||
- `MSGraphClient`:
|
|
||||||
- `__init__(oauth_token)` — initializes MS Graph client
|
|
||||||
- `async fetch_emails(filter_config, since: datetime) -> list[EmailMessage]` (Outlook)
|
|
||||||
- `async fetch_messages(filter_config, since: datetime) -> list[ChatMessage]` (Teams)
|
|
||||||
- `ChatMessage`: `{ id, content, sender, channel, date }`
|
|
||||||
- Handles token refresh via MSAL
|
|
||||||
- [ ] Create `app/integrations/__init__.py` — factory: `get_provider(provider_name) -> GmailClient | MSGraphClient`
|
|
||||||
- **Dependencies:** `google-api-python-client`, `google-auth-oauthlib`, `msgraph-sdk`, `msal`
|
|
||||||
- **Files:** `app/integrations/gmail.py`, `app/integrations/ms_graph.py`, `app/integrations/__init__.py`
|
|
||||||
- **Outcome:** Backend can fetch emails/messages from Gmail, Outlook, and Teams.
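
As a rough illustration of how `filter_config` and `since` might narrow a fetched batch, here is a provider-independent sketch. The `EmailMessage` fields come from the list above; treating `filter_config` as a plain dict, matching on any shared label, and the helper name `matches` are assumptions. `date_range` is left out because its shape is not specified here.

```python
from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class EmailMessage:
    id: str
    subject: str
    sender: str
    body_text: str
    date: datetime
    labels: list = field(default_factory=list)


def matches(message: EmailMessage, filter_config: dict, since: datetime) -> bool:
    """Return True when the message is newer than `since` and passes the filters."""
    if message.date <= since:
        return False
    wanted_labels = filter_config.get("labels")
    # Assumption: a message passes if it carries at least one wanted label.
    if wanted_labels and not set(wanted_labels) & set(message.labels):
        return False
    wanted_senders = filter_config.get("senders")
    if wanted_senders and message.sender not in wanted_senders:
        return False
    return True
```

A real client would more likely push these filters into the provider query; the in-memory check only shows the intended semantics.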
|
|
||||||
|
|
||||||
### Step 3.7 — Agent scheduler
|
|
||||||
- [ ] Create `app/core/agent_scheduler.py`:
|
|
||||||
- Uses `APScheduler` (or simple asyncio loop) to check agent schedules
|
|
||||||
- Every 60s: query enabled agents where `last_run_at + cron_interval < now()`
|
|
||||||
- For each due agent:
|
|
||||||
- Check if user's device is online via `DeviceConnectionManager`
|
|
||||||
- If online: dispatch to `agent_runner`
|
|
||||||
- If offline: skip (will trigger on next `device_hello`)
|
|
||||||
- Locks: use PostgreSQL advisory locks to prevent duplicate runs in multi-instance deployments
|
|
||||||
- [ ] Integrate with FastAPI lifespan (start scheduler on app startup, shutdown gracefully)
|
|
||||||
- **Dependencies:** `apscheduler>=4.0`
|
|
||||||
- **Files:** `app/core/agent_scheduler.py`, `app/main.py`
|
|
||||||
- **Outcome:** Agents run automatically on their configured schedules.
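
The due check and the duplicate-run guard above can be sketched with the standard library alone. The helper names (`is_due`, `advisory_lock_key`, `select_runnable`) and the dict shape of an agent row are hypothetical, and treating never-run agents as due is an assumption. The derived key is meant for PostgreSQL's `pg_try_advisory_lock(bigint)`, which takes a signed 64-bit integer.

```python
import hashlib
from datetime import datetime, timedelta


def is_due(last_run_at, interval: timedelta, now: datetime) -> bool:
    """True when last_run_at + interval < now (assumption: never-run agents are due)."""
    if last_run_at is None:
        return True
    return last_run_at + interval < now


def advisory_lock_key(agent_id: str) -> int:
    """Map an agent id to a stable signed 64-bit integer for an advisory lock."""
    digest = hashlib.sha256(agent_id.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big", signed=True)


def select_runnable(agents: list, online_devices: set, now: datetime) -> list:
    """Return ids of enabled, due agents whose device is currently connected."""
    return [
        agent["id"]
        for agent in agents
        if agent["enabled"]
        and agent["device_id"] in online_devices
        and is_due(agent["last_run_at"], agent["interval"], now)
    ]
```

Agents skipped because their device is offline stay due, so they are picked up on the next pass or on the next `device_hello`.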
|
|
||||||
|
|
||||||
### Step 3.8 — OAuth flow endpoints
|
|
||||||
- [ ] Create `app/api/routes/oauth.py`:
|
|
||||||
- `GET /api/v1/oauth/{provider}/authorize` — returns OAuth authorization URL
|
|
||||||
- Gmail: Google OAuth2 with `gmail.readonly` scope
|
|
||||||
- Outlook/Teams: MS identity platform with `Mail.Read`, `ChannelMessage.Read.All` scopes
|
|
||||||
- `GET /api/v1/oauth/{provider}/callback` — handles OAuth redirect
|
|
||||||
- Exchanges auth code for access + refresh tokens
|
|
||||||
- Encrypts tokens with Fernet (server-side key from settings)
|
|
||||||
- Returns encrypted token blob for storage in `CloudAgentConfig.oauth_token_encrypted`
|
|
||||||
- `POST /api/v1/oauth/{provider}/refresh` — refresh expired OAuth token
|
|
||||||
- **Files:** `app/api/routes/oauth.py`, `app/main.py`
|
|
||||||
- **Outcome:** Users can connect Gmail/Teams/Outlook accounts securely.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Phase 3 — Verification
|
|
||||||
|
|
||||||
| # | Scenario | Expected |
|
|
||||||
|---|---|---|
|
|
||||||
| 1 | **Agent CRUD** | Create/read/update/delete local and cloud configs; tier limits enforced (free=2, pro=10) |
|
|
||||||
| 2 | **WS device connect** | Electron connects → `device_hello` → backend stores connection → triggers overdue runs |
|
|
||||||
| 3 | **Local agent run** | Backend sends `agent_run` → Electron reads files → `agent_data` → backend AI extracts → `tool_call(insert)` → Electron persists with `isAiSuggested=1` |
|
|
||||||
| 4 | **Cloud agent run** | Backend fetches Gmail → AI extracts tasks → `tool_call(insert)` → Electron persists |
|
|
||||||
| 5 | **Device binding** | Local agent config with `device_id=A` only triggers when device A is connected |
|
|
||||||
| 6 | **Chatbot Journey** | Start journey → 3-5 Q&A turns → produces valid `prompt_template` |
|
|
||||||
| 7 | **Schedule** | Agent with `schedule_cron="0 */6 * * *"` runs every 6h when device is online |
|
|
||||||
| 8 | **Offline resilience** | Device offline → runs skipped → device reconnects → overdue runs trigger immediately |
|
|
||||||
| 9 | **OAuth flow** | Gmail authorize → callback → token encrypted → stored in config → fetch emails works |
|
|
||||||
|
|
||||||
### Phase 3 — New Dependencies
|
|
||||||
|
|
||||||
| Package | Purpose |
|
|
||||||
|---|---|
|
|
||||||
| `google-api-python-client` | Gmail API access |
|
|
||||||
| `google-auth-oauthlib` | Gmail OAuth2 flow |
|
|
||||||
| `msgraph-sdk` | Outlook + Teams API access |
|
|
||||||
| `msal` | MS identity platform auth |
|
|
||||||
| `apscheduler>=4.0` | Agent scheduling |
|
|
||||||
| `cryptography` (Fernet) | OAuth token encryption at rest |
|
|
||||||
- **One step at a time.** Mark `[x]` and commit with `step N.N complete: <outcome>`.
|
|
||||||
572
BACKEND_PLAN.md
@@ -1,572 +0,0 @@
|
|||||||
# Backend Plan — Adiuva Cloud API
|
|
||||||
|
|
||||||
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
|
|
||||||
>
|
|
||||||
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, E2E backup blob storage, cloud storage (encrypted blobs), cloud vector store, and plugin marketplace.
|
|
||||||
> The backend NEVER persists user data in plaintext. Cloud storage blobs are E2E encrypted before upload — the backend only verifies integrity, never decrypts.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── app/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── main.py # FastAPI entry + CORS + lifespan + router includes
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── agent_registry.py # Base classes + singleton registry
|
|
||||||
│ │ ├── orchestrator.py # LLM-based intent router
|
|
||||||
│ │ ├── execution_plan.py # Plan builder + cache
|
|
||||||
│ │ └── plugin_loader.py # Dynamic agent loading
|
|
||||||
│ ├── agents/ # Chat agents (proprietary logic + prompts)
|
|
||||||
│ │ ├── __init__.py # Auto-registers all agents
|
|
||||||
│ │ ├── task_agent.py
|
|
||||||
│ │ ├── calendar_agent.py
|
|
||||||
│ │ ├── email_agent.py
|
|
||||||
│ │ └── analytics_agent.py
|
|
||||||
│ ├── api/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── routes/
|
|
||||||
│ │ │ ├── __init__.py
|
|
||||||
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
|
|
||||||
│ │ │ ├── plans.py # GET /plans/playbook
|
|
||||||
│ │ │ ├── storage.py # CRUD cloud storage (E2E encrypted blobs)
|
|
||||||
│ │ │ ├── vectors.py # Upsert/search cloud vector store
|
|
||||||
│ │ │ ├── backup.py # PUT/GET /backup
|
|
||||||
│ │ │ ├── plugins.py # Plugin marketplace
|
|
||||||
│ │ │ ├── auth.py # Register/login/refresh
|
|
||||||
│ │ │ └── billing.py # Checkout/webhook/subscription
|
|
||||||
│ │ └── middleware/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── auth.py # JWT validation
|
|
||||||
│ │ ├── rate_limit.py # Tier-aware rate limiting
|
|
||||||
│ │ └── sanitizer.py # Strip prompt metadata from responses
|
|
||||||
│ ├── storage/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── blob_store.py # S3 for E2E encrypted blobs
|
|
||||||
│ │ ├── vector_store.py # Cloud vector store (Pinecone/Qdrant)
|
|
||||||
│ │ └── encryption.py # Integrity verification only — NO decryption
|
|
||||||
│ ├── marketplace/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── plugin_registry.py # Plugin catalog (metadata, versions, ratings)
|
|
||||||
│ │ ├── plugin_review.py # Review queue + approval workflow
|
|
||||||
│ │ └── revenue_share.py # 70/30 split tracking with Stripe Connect
|
|
||||||
│ ├── billing/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── stripe_service.py # Stripe checkout + webhooks
|
|
||||||
│ │ └── tier_manager.py # Feature matrix per tier
|
|
||||||
│ └── config/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ └── settings.py # Pydantic BaseSettings (env-based)
|
|
||||||
├── tests/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── conftest.py # Fixtures: test client, mock agents, mock LLM
|
|
||||||
│ ├── test_orchestrator.py
|
|
||||||
│ ├── test_agents.py
|
|
||||||
│ ├── test_auth.py
|
|
||||||
│ ├── test_backup.py
|
|
||||||
│ ├── test_storage.py
|
|
||||||
│ └── test_plugins.py
|
|
||||||
├── alembic/ # DB migrations (auth/billing/marketplace tables only)
|
|
||||||
│ ├── alembic.ini
|
|
||||||
│ └── versions/
|
|
||||||
├── requirements.txt
|
|
||||||
├── Dockerfile
|
|
||||||
├── docker-compose.yml # App + PostgreSQL + Redis (dev)
|
|
||||||
├── .env.example
|
|
||||||
└── README.md
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step-by-Step Implementation
|
|
||||||
|
|
||||||
### Step 1 — Project scaffolding ✅
|
|
||||||
- [x] Initialize repo with the directory structure above
|
|
||||||
- [x] Write `requirements.txt`:
|
|
||||||
```
|
|
||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
langchain>=0.3.0
|
|
||||||
langchain-openai>=0.3.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
python-jose[cryptography]>=3.3.0
|
|
||||||
stripe>=11.0.0
|
|
||||||
boto3>=1.35.0
|
|
||||||
slowapi>=0.1.9
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
asyncpg>=0.30.0
|
|
||||||
alembic>=1.14.0
|
|
||||||
bcrypt>=4.2.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
httpx>=0.28.0
|
|
||||||
websockets>=14.0
|
|
||||||
pytest>=8.0.0
|
|
||||||
pytest-asyncio>=0.24.0
|
|
||||||
```
|
|
||||||
- [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
|
|
||||||
- [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod), `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
|
||||||
- [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
|
|
||||||
- [x] Write `docker-compose.yml`: app, postgres:16, optional redis
|
|
||||||
- [x] Write `.env.example`
|
|
||||||
- **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes).
|
|
||||||
|
|
||||||
### Step 2 — Pydantic schemas (API contracts) ✅
|
|
||||||
- [x] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
|
|
||||||
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
|
|
||||||
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
|
|
||||||
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
|
|
||||||
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification', 'call_agent']`, `table: str | None`, `data: dict | None`, `agent: str | None`
|
|
||||||
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
|
|
||||||
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
|
|
||||||
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
|
|
||||||
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
|
|
||||||
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
|
|
||||||
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
|
||||||
- `StorageRecord`: `id: str`, `user_id: str`, `table: str`, `blob: bytes`, `checksum: str`, `created_at: int`, `updated_at: int` — blob is always E2E encrypted by client
|
|
||||||
- `StorageRecordCreate`: `table: str`, `blob: bytes`, `checksum: str`
|
|
||||||
- `StorageRecordUpdate`: `blob: bytes`, `checksum: str`
|
|
||||||
- `VectorUpsertRequest`: `vectors: list[VectorItem]`
|
|
||||||
- `VectorItem`: `id: str`, `blob: bytes`, `checksum: str` — vector + metadata encrypted by client
|
|
||||||
- `VectorSearchRequest`: `query_blob: bytes`, `top_k: int = 10`
|
|
||||||
- `VectorSearchResponse`: `results: list[VectorSearchResult]`
|
|
||||||
- `VectorSearchResult`: `id: str`, `score: float`, `blob: bytes`
|
|
||||||
- `PluginManifest`: `id: str`, `name: str`, `description: str`, `version: str`, `author: str`, `permissions: list[str]`, `category: str`, `price_cents: int = 0`
|
|
||||||
- `PluginListResponse`: `plugins: list[PluginManifest]`, `total: int`, `page: int`
|
|
||||||
- `PluginInstallRequest`: `plugin_id: str`
|
|
||||||
- **Outcome:** All request/response models defined and validated.
|
|
||||||
|
|
||||||
### Step 3 — Agent Registry + base classes ✅
|
|
||||||
- [x] `app/core/agent_registry.py`:
|
|
||||||
- `BaseAgent(ABC)`:
|
|
||||||
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
|
|
||||||
- Abstract `get_name() -> str`, `get_description() -> str`
|
|
||||||
- `ChatAgent(BaseAgent)`:
|
|
||||||
- Abstract `async handle(query: str, context: dict) -> str`
|
|
||||||
- Abstract `get_tools() -> list` (LangChain tool definitions)
|
|
||||||
- Concrete `_tool_loop(llm, messages, tools, max_iter=5) -> str` — shared tool-calling loop
|
|
||||||
- `AgentRegistry` (singleton):
|
|
||||||
- `_agents: dict[str, ChatAgent]`
|
|
||||||
- `register(agent_class)` — decorator pattern
|
|
||||||
- `get(name) -> ChatAgent`
|
|
||||||
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
|
|
||||||
- `async call_agent(name, query, context) -> str` — for inter-agent calls
|
|
||||||
- [x] Unit tests: register, get, list, call_agent with mock
|
|
||||||
- **Outcome:** Pluggable agent framework.
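
A minimal sketch of the decorator-based registry, assuming the registry stores one instance per agent and that a module-level object serves as the singleton. `EchoAgent` is a hypothetical agent used only for illustration; the `BaseAgent` fields and `_tool_loop` are omitted for brevity.

```python
import asyncio
from abc import ABC, abstractmethod


class ChatAgent(ABC):
    @abstractmethod
    def get_name(self) -> str: ...

    @abstractmethod
    def get_description(self) -> str: ...

    @abstractmethod
    async def handle(self, query: str, context: dict) -> str: ...


class AgentRegistry:
    def __init__(self) -> None:
        self._agents: dict = {}

    def register(self, agent_class):
        """Class decorator: instantiate the agent and index it by name."""
        agent = agent_class()
        self._agents[agent.get_name()] = agent
        return agent_class

    def get(self, name: str) -> ChatAgent:
        return self._agents[name]

    def list_agents(self) -> list:
        return [
            {"name": a.get_name(), "description": a.get_description()}
            for a in self._agents.values()
        ]

    async def call_agent(self, name: str, query: str, context: dict) -> str:
        return await self.get(name).handle(query, context)


registry = AgentRegistry()  # module-level instance acts as the singleton


@registry.register
class EchoAgent(ChatAgent):  # hypothetical agent, for illustration only
    def get_name(self) -> str:
        return "echo_agent"

    def get_description(self) -> str:
        return "Repeats the query back"

    async def handle(self, query: str, context: dict) -> str:
        return f"echo: {query}"
```

Importing a module that contains a decorated class is enough to register it, which is what an `app/agents/__init__.py` that imports every agent module relies on.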
|
|
||||||
|
|
||||||
### Step 4 — Orchestrator ✅
|
|
||||||
- [x] `app/core/orchestrator.py`:
|
|
||||||
- `async classify_intent(message, context, registry) -> str`:
|
|
||||||
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
|
|
||||||
- Uses gpt-4o-mini via LangChain for low latency
|
|
||||||
- Falls back to `task_agent` if no clear match
|
|
||||||
- `async route_single(agent_name, message, context) -> ChatResponse`:
|
|
||||||
- Instantiates agent from registry
|
|
||||||
- Calls `agent.handle(message, context)`
|
|
||||||
- Returns response + any actions the agent produced
|
|
||||||
- `async route_pipeline(agent_names, message, context) -> ChatResponse`:
|
|
||||||
- Executes agents in sequence
|
|
||||||
- Each agent receives `{...context, previous_results: [...]}`
|
|
||||||
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
|
|
||||||
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
|
|
||||||
- Main entry point
|
|
||||||
- The orchestrator treats context as source-agnostic: data may originate from local or cloud storage on the client side
|
|
||||||
- Classifies intent
|
|
||||||
- If `execution_mode == 'direct'`: route + return response
|
|
||||||
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
|
|
||||||
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
|
|
||||||
- Same as orchestrate but yields tokens for WebSocket streaming
|
|
||||||
- [x] Integration tests with mocked LLM and mocked agents
|
|
||||||
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
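
The fallback rule in `classify_intent` can be illustrated without an LLM. The sketch below assumes the classifier's raw reply is post-processed by a hypothetical helper, `resolve_agent`, which normalizes the text and falls back to `task_agent` when the reply does not name a registered agent.

```python
def resolve_agent(raw_reply: str, known_agents: list, default: str = "task_agent") -> str:
    """Normalize the classifier reply; fall back to `default` on no clear match."""
    # Trim whitespace and stray quoting the model may wrap around the name.
    candidate = raw_reply.strip().strip("`'\"").lower()
    return candidate if candidate in known_agents else default
```

For example, a reply of `" Note_Agent\n"` resolves to `note_agent`, while a free-form sentence resolves to the default.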
|
|
||||||
|
|
||||||
### Step 5 — Execution Plan generator ✅
|
|
||||||
- [x] `app/core/execution_plan.py`:
|
|
||||||
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
|
|
||||||
- `ExecutionPlanBuilder`:
|
|
||||||
- `add_step(action, params) -> self`
|
|
||||||
- `add_llm_step(template_id, variables) -> self`
|
|
||||||
- `add_data_step(action, data_from_step) -> self`
|
|
||||||
- `build() -> ExecutionPlan` — validates step references
|
|
||||||
- `PlanCache`:
|
|
||||||
- In-memory LRU (maxsize=1000)
|
|
||||||
- `cache_plan(key, plan)`, `get_plan(key)`, `get_all_playbooks() -> list[ExecutionPlan]`
|
|
||||||
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
|
|
||||||
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
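
One way the builder's reference validation and the LRU cache could fit together is sketched below. It assumes zero-based step indices, that `data_from_step` may only point at an earlier step, and that LLM steps use the action name `"llm"`; none of these details are confirmed by the plan.

```python
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class PlanStep:
    action: str
    prompt_template: Optional[str] = None  # a template ID, never the prompt text
    variables: Optional[dict] = None
    data_from_step: Optional[int] = None


@dataclass
class ExecutionPlan:
    agent: str
    steps: list = field(default_factory=list)


class ExecutionPlanBuilder:
    def __init__(self, agent: str) -> None:
        self._agent = agent
        self._steps = []

    def add_llm_step(self, template_id: str, variables: dict):
        self._steps.append(PlanStep("llm", prompt_template=template_id, variables=variables))
        return self

    def add_data_step(self, action: str, data_from_step: int):
        self._steps.append(PlanStep(action, data_from_step=data_from_step))
        return self

    def build(self) -> ExecutionPlan:
        # Each reference must point at a step that comes earlier in the plan.
        for index, step in enumerate(self._steps):
            ref = step.data_from_step
            if ref is not None and not 0 <= ref < index:
                raise ValueError(f"step {index} references invalid step {ref}")
        return ExecutionPlan(self._agent, list(self._steps))


class PlanCache:
    def __init__(self, maxsize: int = 1000) -> None:
        self._maxsize = maxsize
        self._plans = OrderedDict()

    def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
        self._plans[key] = plan
        self._plans.move_to_end(key)
        if len(self._plans) > self._maxsize:
            self._plans.popitem(last=False)  # evict the least recently used plan

    def get_plan(self, key: str):
        plan = self._plans.get(key)
        if plan is not None:
            self._plans.move_to_end(key)  # mark as recently used
        return plan
```

Because a step carries only a template ID, a serialized plan can be sent to the client without exposing prompt text.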
|
|
||||||
|
|
||||||
### Step 6 — Chat Agents ✅
|
|
||||||
- [x] `app/agents/task_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
- Tools (8): `list_tasks(project_id, status, search, order_by)`, `create_task(title, description, status, priority, assignees, due_date, project_id, is_ai_suggested, is_approved)`, `update_task(task_id, ...)`, `delete_task(task_id)`, `list_tasks_due_today()`, `list_task_comments(task_id)`, `add_task_comment(task_id, author, content)`, `delete_task_comment(comment_id)`
|
|
||||||
- status: `todo|in_progress|done`; priority: `high|medium|low`; assignees: JSON-encoded string; due_date: ms timestamp
|
|
||||||
- Accepts flexible context; sentinel `-1` for optional integer update fields
|
|
||||||
- [x] `app/agents/checkpoint_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages project checkpoints (milestones): list, create, update, delete"
|
|
||||||
- Tools (4): `list_checkpoints(project_id)`, `create_checkpoint(project_id, title, date, is_ai_suggested, is_approved)`, `update_checkpoint(checkpoint_id, ...)`, `delete_checkpoint(checkpoint_id)`
|
|
||||||
- `project_id` is required for create; date is a ms timestamp; supports AI-suggestion + approval workflow
|
|
||||||
- [x] `app/agents/project_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
- Tools (6): `list_projects(client_id, include_archived)`, `list_all_projects()`, `get_project(project_id)`, `create_project(name, client_id)`, `update_project(project_id, ...)`, `delete_project(project_id)`
|
|
||||||
- status: `active|archived`; prefers archive over deletion (docstring guard on delete)
|
|
||||||
- [x] `app/agents/note_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages notes: list, get, create, update, delete"
|
|
||||||
- Tools (5): `list_notes(project_id)`, `get_note(note_id)`, `create_note(title, content, project_id)`, `update_note(note_id, ...)`, `delete_note(note_id)`
|
|
||||||
- content is Markdown; `get_note` should be called before update to preserve existing content
|
|
||||||
- [x] `app/agents/__init__.py`: imports all four agent modules to trigger `@registry.register` decorators
|
|
||||||
- [x] Unit tests per agent with mocked LLM (registration, names, tool counts, handle(), direct tool invocation)
|
|
||||||
- **Outcome:** Four domain-specific agents matching the UI data model (Tasks, Checkpoints, Projects, Notes), all registered and tested.
|
|
||||||
|
|
||||||
### Step 7 — Storage Layer ✅
|
|
||||||
- [x] `app/storage/blob_store.py`:
|
|
||||||
- `BlobStore`: `async upload`, `async download`, `async delete` (idempotent), `async list_keys`
|
|
||||||
- Keys: `{user_id}/{table}/{record_id}` — backend never inspects blob content
|
|
||||||
- boto3 S3 with SSE-S3 at-rest encryption; client checksum stored in S3 object metadata
|
|
||||||
- [x] `app/storage/vector_store.py`:
|
|
||||||
- `VectorStore`: `async upsert`, `async search`, `async delete`
|
|
||||||
- Pinecone (default, `namespace=user_id`) or Qdrant (`user_id` payload filter) — runtime-configurable
|
|
||||||
- 32-dim SHA-256-derived float vector; blob stored as base64 in metadata/payload
|
|
||||||
- ANN on encrypted data: known accuracy trade-off, documented
|
|
||||||
- [x] `app/storage/encryption.py`:
|
|
||||||
- `verify_checksum(blob, checksum) -> bool` — SHA-256 + `hmac.compare_digest` (constant-time)
|
|
||||||
- `reject_if_tampered(blob, checksum)` — raises `HTTP 400` on mismatch
|
|
||||||
- Backend NEVER holds decryption keys
|
|
||||||
- [x] `app/schemas.py`: added `StorageRecord*`, `VectorItem`, `VectorUpsertRequest`, `VectorSearch*`, `Plugin*` schemas
|
|
||||||
- [x] `app/config/settings.py`: added `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
|
||||||
- [x] `requirements.txt`: added `moto[s3]`, `pinecone`, `qdrant-client`
|
|
||||||
- [x] 37 unit tests covering encryption, BlobStore (moto), VectorStore Pinecone, VectorStore Qdrant
|
|
||||||
- **Outcome:** Cloud storage layer that handles E2E encrypted blobs without ever accessing plaintext.
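
The integrity check and key layout from this step can be sketched with `hashlib` and `hmac`. Two details are assumptions: the client checksum is taken to be a hex-encoded SHA-256 digest, and the 32-dimensional vector is produced by scaling each digest byte into the range 0 to 1. `blob_key` and `digest_vector` are hypothetical helper names.

```python
import hashlib
import hmac


def verify_checksum(blob: bytes, checksum: str) -> bool:
    """Compare the blob's SHA-256 hex digest with the client checksum in constant time."""
    expected = hashlib.sha256(blob).hexdigest()
    return hmac.compare_digest(expected.encode("ascii"), checksum.lower().encode("utf-8"))


def blob_key(user_id: str, table: str, record_id: str) -> str:
    """Build the object key in the `{user_id}/{table}/{record_id}` layout."""
    return f"{user_id}/{table}/{record_id}"


def digest_vector(blob: bytes) -> list:
    """Derive a 32-dimensional float vector from the 32-byte SHA-256 digest."""
    return [byte / 255.0 for byte in hashlib.sha256(blob).digest()]
```

None of these helpers needs a decryption key: they operate only on the ciphertext bytes the client uploads.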
|
|
||||||
|
|
||||||
### Step 8 — API Routes ✅
|
|
||||||
|
|
||||||
#### 8a — Chat endpoint
|
|
||||||
- [x] `app/api/routes/chat.py`:
|
|
||||||
- `POST /api/v1/chat`:
|
|
||||||
- Request: `ChatRequest`
|
|
||||||
- Calls `orchestrate(request)` in direct mode, or `orchestrate()` followed by `build_plan()` in plan mode
|
|
||||||
- Response: `ChatResponse` or `ExecutionPlan`
|
|
||||||
- `WebSocket /api/v1/chat/stream`:
|
|
||||||
- Client sends `ChatRequest` as first JSON frame
|
|
||||||
- Server yields token strings via `orchestrate_stream()`
|
|
||||||
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
|
|
||||||
- Heartbeat ping every 30s to keep connection alive
|
|
||||||
|
|
||||||
#### 8b — Plans endpoint
|
|
||||||
- [x] `app/api/routes/plans.py`:
|
|
||||||
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
|
|
||||||
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
|
|
||||||
|
|
||||||
#### 8c — Storage endpoint (cloud records)
|
|
||||||
- [x] `app/api/routes/storage.py`:
|
|
||||||
- `POST /api/v1/storage/records`: Create encrypted record
|
|
||||||
- Request: `StorageRecordCreate`
|
|
||||||
- Verifies checksum, stores blob in S3, inserts metadata row in PostgreSQL
|
|
||||||
- Response: `{id: str, created_at: int}`
|
|
||||||
- `GET /api/v1/storage/records`: List record metadata (no blobs)
|
|
||||||
- Query params: `table: str`, `page: int`, `limit: int`
|
|
||||||
- Response: `list[{id, table, checksum, created_at, updated_at}]`
|
|
||||||
- `GET /api/v1/storage/records/{id}`: Download encrypted blob
|
|
||||||
- Response: blob bytes + `X-Checksum` header
|
|
||||||
- `PUT /api/v1/storage/records/{id}`: Update encrypted blob
|
|
||||||
- Request: `StorageRecordUpdate`
|
|
||||||
- `DELETE /api/v1/storage/records/{id}`: Delete record + S3 blob
|
|
||||||
- All routes enforce the tier's `cloud_storage_gb` quota via `TierManager.check_quota(user_id)`
|
|
||||||
|
|
||||||
#### 8d — Vectors endpoint (cloud vector store)
|
|
||||||
- [x] `app/api/routes/vectors.py`:
|
|
||||||
- `POST /api/v1/storage/vectors/upsert`:
|
|
||||||
- Request: `VectorUpsertRequest`
|
|
||||||
- Verifies checksums, delegates to `VectorStore.upsert()`
|
|
||||||
- Response: `{upserted: int}`
|
|
||||||
- `POST /api/v1/storage/vectors/search`:
|
|
||||||
- Request: `VectorSearchRequest`
|
|
||||||
- Delegates to `VectorStore.search()`
|
|
||||||
- Response: `VectorSearchResponse`
|
|
||||||
- `DELETE /api/v1/storage/vectors`:
|
|
||||||
- Request: `{ids: list[str]}`
|
|
||||||
|
|
||||||
#### 8e — Backup endpoint
|
|
||||||
- [x] `app/api/routes/backup.py`:
|
|
||||||
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
|
|
||||||
- Free: 0 (no backup)
|
|
||||||
- Pro: 5 GB
|
|
||||||
- Power: 25 GB
|
|
||||||
- Team: unlimited
|
|
||||||
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
|
|
||||||
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
|
|
||||||
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
|
|
||||||
|
|
||||||
#### 8f — Plugins endpoint
|
|
||||||
- [x] `app/api/routes/plugins.py`:
|
|
||||||
- `GET /api/v1/plugins`:
|
|
||||||
- Query params: `category: str | None`, `q: str | None`, `page: int`, `sort: Literal['rating', 'installs', 'newest']`
|
|
||||||
- Response: `PluginListResponse`
|
|
||||||
- Available from Power tier and above
|
|
||||||
- `GET /api/v1/plugins/{id}`:
|
|
||||||
- Response: `PluginManifest` + ratings + install count
|
|
||||||
- `POST /api/v1/plugins/{id}/install`:
|
|
||||||
- Request: `PluginInstallRequest`
|
|
||||||
- Records installation for the user (billing tracking, analytics)
|
|
||||||
- If plugin is paid: triggers Stripe Connect charge + revenue split (70% developer, 30% platform)
|
|
||||||
- Response: `{ok: true, download_url: str}` — signed S3 URL for plugin package
|
|
||||||
- `DELETE /api/v1/plugins/{id}/install`:
|
|
||||||
- Unregisters installation
|
|
||||||
|
|
||||||
#### 8g — Auth endpoint
|
|
||||||
- [x] `app/api/routes/auth.py`:
|
|
||||||
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
|
|
||||||
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
|
|
||||||
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
|
|
||||||
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
|
|
||||||
|
|
||||||
#### 8h — Billing endpoint
|
|
||||||
- [x] `app/api/routes/billing.py`:
|
|
||||||
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
|
|
||||||
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
|
|
||||||
- `GET /api/v1/billing/subscription`: Returns current subscription info
|
|
||||||
- `DELETE /api/v1/billing/subscription`: Cancels subscription
|
|
||||||
|
|
||||||
- **Outcome:** Complete REST + WebSocket API covering orchestration, storage, vectors, backup, marketplace.
|
|
||||||
|
|
||||||
### Step 9 — Middleware
|
|
||||||
|
|
||||||
#### 9a — Auth middleware
|
|
||||||
- [x] `app/api/middleware/auth.py`:
|
|
||||||
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
|
||||||
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
|
||||||
- Raises `401` on invalid/expired token
|
|
||||||
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
|
||||||
|
|
||||||
#### 9b — Rate limiter
|
|
||||||
- [x] `app/api/middleware/rate_limit.py`:
|
|
||||||
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
|
||||||
- Tier-based limits:
|
|
||||||
- Free: 20 req/min
|
|
||||||
- Pro: 60 req/min
|
|
||||||
- Power: 120 req/min
|
|
||||||
- Team: 200 req/seat/min
|
|
||||||
- Custom 429 response with `Retry-After` header
|
|
||||||
|
|
||||||
#### 9c — Sanitizer
|
|
||||||
- [x] `app/api/middleware/sanitizer.py`:
|
|
||||||
- Response middleware that scans response bodies
|
|
||||||
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
|
||||||
- Pattern-based detection + exact match against known prompt fingerprints
|
|
||||||
- Logs sanitization events for monitoring
|
|
||||||
|
|
||||||
- **Outcome:** Secure, rate-limited API with prompt IP protection.
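
To make the checks in 9a concrete, here is a standard-library sketch of HS256 token validation: signature, algorithm, and expiry. It is a stand-in for illustration only, since the plan lists `python-jose` for this job. The claim names `sub` and `tier` and the helper names are assumptions, and a caller would map the raised `ValueError` to a `401`.

```python
import base64
import hashlib
import hmac
import json
import time


def _b64url_encode(raw: bytes) -> str:
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def _b64url_decode(segment: str) -> bytes:
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def _signature(signing_input: str, secret: str) -> bytes:
    return hmac.new(secret.encode("utf-8"), signing_input.encode("ascii"), hashlib.sha256).digest()


def sign_hs256(claims: dict, secret: str) -> str:
    """Issue a compact HS256 token (used here only to exercise the verifier)."""
    header = _b64url_encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode("utf-8"))
    payload = _b64url_encode(json.dumps(claims).encode("utf-8"))
    signature = _b64url_encode(_signature(f"{header}.{payload}", secret))
    return f"{header}.{payload}.{signature}"


def verify_hs256(token: str, secret: str, now=None) -> dict:
    """Return the claims if the signature and expiry are valid, else raise ValueError."""
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("malformed token")
    header_b64, payload_b64, signature_b64 = parts
    if json.loads(_b64url_decode(header_b64)).get("alg") != "HS256":
        raise ValueError("unexpected algorithm")
    expected = _signature(f"{header_b64}.{payload_b64}", secret)
    if not hmac.compare_digest(expected, _b64url_decode(signature_b64)):
        raise ValueError("bad signature")
    claims = json.loads(_b64url_decode(payload_b64))
    current = time.time() if now is None else now
    if claims.get("exp", 0) <= current:
        raise ValueError("token expired")
    return claims
```

Pinning the accepted algorithm before checking the signature keeps a token from choosing its own verification method.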
|
|
||||||
|
|
||||||
### Step 10 — Plugin Marketplace ✅
|
|
||||||
- [x] `app/marketplace/plugin_registry.py`:
|
|
||||||
- `PluginRegistry`:
|
|
||||||
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
|
|
||||||
- `async get_plugin(plugin_id) -> PluginManifest | None`
|
|
||||||
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
|
|
||||||
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
|
|
||||||
- `async reject_plugin(plugin_id, reason: str) -> None`
|
|
||||||
- [x] `app/marketplace/plugin_review.py`:
|
|
||||||
- `ReviewQueue`:
|
|
||||||
- `async get_pending() -> list[dict]`
|
|
||||||
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
|
|
||||||
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
|
|
||||||
- [x] `app/marketplace/revenue_share.py`:
|
|
||||||
- `RevenueShare`:
|
|
||||||
- `async record_install(plugin_id, user_id, amount_cents) -> None`
|
|
||||||
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
|
|
||||||
- `async get_earnings(developer_id, period) -> dict`
|
|
||||||
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
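
The 70/30 split is simple integer arithmetic on cents. The sketch below assumes the developer share is rounded down and the platform keeps the remainder, so the two parts always sum to the charged amount; `split_revenue` is a hypothetical helper name.

```python
DEVELOPER_SHARE_PERCENT = 70


def split_revenue(amount_cents: int) -> tuple:
    """Return (developer_share_cents, platform_share_cents) for one paid install."""
    if amount_cents < 0:
        raise ValueError("amount_cents must be non-negative")
    developer = amount_cents * DEVELOPER_SHARE_PERCENT // 100
    return developer, amount_cents - developer
```

For a 999-cent plugin this yields 699 cents for the developer and 300 for the platform.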
|
|
||||||
|
|
||||||
### Step 11 — Billing & Tier management ✅
|
|
||||||
- [x] `app/billing/stripe_service.py`:
|
|
||||||
- `create_checkout_session(user_id, tier) -> str`
|
|
||||||
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
|
||||||
- `get_subscription(user_id) -> dict | None`
|
|
||||||
- `cancel_subscription(user_id) -> None`
|
|
||||||
- [x] `app/billing/tier_manager.py`:
|
|
||||||
- `TierManager`:
|
|
||||||
- Feature matrix:
|
|
||||||
```python
|
|
||||||
FEATURES = {
|
|
||||||
'free': {
|
|
||||||
'agents': 3,
|
|
||||||
'batch_active': 2,
|
|
||||||
'cloud_storage_gb': 0,
|
|
||||||
'backup_gb': 0,
|
|
||||||
'providers': 1,
|
|
||||||
'batch_builder': False,
|
|
||||||
'plugin_marketplace': False,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'pro': {
|
|
||||||
'agents': -1, # unlimited
|
|
||||||
'batch_active': 10,
|
|
||||||
'cloud_storage_gb': 5,
|
|
||||||
'backup_gb': 5,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': False,
|
|
||||||
'plugin_marketplace': False,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'power': {
|
|
||||||
'agents': -1,
|
|
||||||
'batch_active': -1, # unlimited
|
|
||||||
'cloud_storage_gb': 25,
|
|
||||||
'backup_gb': 25,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': True,
|
|
||||||
'plugin_marketplace': True,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'team': {
|
|
||||||
'agents': -1,
|
|
||||||
'batch_active': -1,
|
|
||||||
'cloud_storage_gb': -1,
|
|
||||||
'backup_gb': -1,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': True,
|
|
||||||
'plugin_marketplace': True,
|
|
||||||
'sso': True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
```
|
|
||||||
- `get_tier(user_id) -> BillingTier`
|
|
||||||
- `check_feature(user_id, feature) -> bool`
|
|
||||||
- `get_rate_limit(tier) -> int`
|
|
||||||
- `check_quota(user_id) -> bool`: checks current `cloud_storage_gb` usage against the tier limit
|
|
||||||
- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons
|
|
||||||
- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService`
|
|
||||||
- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota`
|
|
||||||
- **Outcome:** Stripe integration with tier-based feature gating matching the pricing tiers: Free / Pro (15€) / Power (29€) / Team (49€/seat).
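
A short sketch of how the feature matrix might be queried, using `-1` as the unlimited sentinel as in the matrix above. To stay self-contained it takes a tier name rather than a `user_id` and uses an excerpt of the matrix; `within_limit` is a hypothetical helper.

```python
UNLIMITED = -1

# Excerpt of the feature matrix, limited to three features and three tiers.
FEATURES = {
    "free": {"batch_active": 2, "cloud_storage_gb": 0, "plugin_marketplace": False},
    "pro": {"batch_active": 10, "cloud_storage_gb": 5, "plugin_marketplace": False},
    "power": {"batch_active": UNLIMITED, "cloud_storage_gb": 25, "plugin_marketplace": True},
}


def check_feature(tier: str, feature: str) -> bool:
    """True when the tier has the feature at all (flag set, or a non-zero limit)."""
    value = FEATURES[tier][feature]
    if isinstance(value, bool):
        return value
    return value == UNLIMITED or value > 0


def within_limit(tier: str, feature: str, current_usage: float) -> bool:
    """True when one more unit of usage would still fit under the tier's limit."""
    limit = FEATURES[tier][feature]
    return limit == UNLIMITED or current_usage < limit
```

Checking the sentinel before comparing avoids treating `-1` as a limit that every usage value exceeds.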
|
|
||||||
|
|
||||||
### Step 12 — Database (auth/billing/marketplace only)
|
|
||||||
- [x] PostgreSQL schema via Alembic:
|
|
||||||
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
|
|
||||||
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
|
|
||||||
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
|
|
||||||
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
|
|
||||||
- `storage_records`: `id UUID PK`, `user_id FK`, `table_name VARCHAR`, `s3_key`, `checksum`, `size_bytes`, `created_at`, `updated_at` — metadata only, no plaintext
|
|
||||||
- `plugins`: `id UUID PK`, `name`, `description`, `version`, `author_id FK`, `category`, `status` (pending_review/approved/rejected), `price_cents`, `s3_package_key`, `install_count`, `avg_rating`, `created_at`
|
|
||||||
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
|
|
||||||
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
|
|
||||||
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
|
|
||||||
- [x] Initial Alembic migration
|
|
||||||
- [x] SQLAlchemy models in `app/models.py`
|
|
||||||
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
|
|
||||||
|
|
||||||
### Step 13 — Testing & deployment ✅
|
|
||||||
- [x] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone
|
|
||||||
- [x] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
|
|
||||||
- [x] `tests/test_agents.py`: each agent with mocked tools
|
|
||||||
- [x] `tests/test_auth.py`: register → login → access protected → refresh → expired token
|
|
||||||
- [x] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
|
|
||||||
- [x] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement
|
|
||||||
- [x] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked)
|
|
||||||
- [x] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
|
|
||||||
- [x] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
|
|
||||||
- **Outcome:** Fully tested, deployable backend.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## API Contract Summary
|
|
||||||
|
|
||||||
| Method | Endpoint | Auth | Request | Response |
|
|
||||||
|--------|----------|------|---------|----------|
|
|
||||||
| POST | `/api/v1/auth/register` | No | `{email, password}` | `AuthTokens` |
|
|
||||||
| POST | `/api/v1/auth/login` | No | `{email, password}` | `AuthTokens` |
|
|
||||||
| POST | `/api/v1/auth/refresh` | No | `{refresh_token}` | `AuthTokens` |
|
|
||||||
| GET | `/api/v1/auth/me` | JWT | — | `UserProfile` |
|
|
||||||
| POST | `/api/v1/chat` | JWT | `ChatRequest` | `ChatResponse \| ExecutionPlan` |
|
|
||||||
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
|
|
||||||
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
|
|
||||||
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
|
|
||||||
| POST | `/api/v1/storage/records` | JWT | `StorageRecordCreate` | `{id, created_at}` |
|
|
||||||
| GET | `/api/v1/storage/records` | JWT | `?table&page&limit` | `RecordMeta[]` |
|
|
||||||
| GET | `/api/v1/storage/records/:id` | JWT | — | Binary blob |
|
|
||||||
| PUT | `/api/v1/storage/records/:id` | JWT | `StorageRecordUpdate` | `{ok: true}` |
|
|
||||||
| DELETE | `/api/v1/storage/records/:id` | JWT | — | `{ok: true}` |
|
|
||||||
| POST | `/api/v1/storage/vectors/upsert` | JWT | `VectorUpsertRequest` | `{upserted: int}` |
|
|
||||||
| POST | `/api/v1/storage/vectors/search` | JWT | `VectorSearchRequest` | `VectorSearchResponse` |
|
|
||||||
| DELETE | `/api/v1/storage/vectors` | JWT | `{ids: list[str]}` | `{ok: true}` |
|
|
||||||
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/backup` | JWT | — | Binary blob |
|
|
||||||
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
|
|
||||||
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/plugins` | JWT | `?category&q&page&sort` | `PluginListResponse` |
|
|
||||||
| GET | `/api/v1/plugins/:id` | JWT | — | `PluginManifest` + stats |
|
|
||||||
| POST | `/api/v1/plugins/:id/install` | JWT | `PluginInstallRequest` | `{ok, download_url}` |
|
|
||||||
| DELETE | `/api/v1/plugins/:id/install` | JWT | — | `{ok: true}` |
|
|
||||||
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
|
|
||||||
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
|
|
||||||
| DELETE | `/api/v1/billing/subscription` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/health` | No | — | `{status, version}` |
|
|
||||||
| GET | `/api/v1/agents/catalog` | JWT | — | `AgentCatalogItem[]` |
|
|
||||||
| GET | `/api/v1/agents/local` | JWT | — | `LocalAgentConfigResponse[]` |
|
|
||||||
| POST | `/api/v1/agents/local` | JWT | `LocalAgentConfigCreate` | `LocalAgentConfigResponse` |
|
|
||||||
| PUT | `/api/v1/agents/local/{id}` | JWT | `LocalAgentConfigUpdate` | `LocalAgentConfigResponse` |
|
|
||||||
| DELETE | `/api/v1/agents/local/{id}` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/agents/cloud` | JWT | — | `CloudAgentConfigResponse[]` |
|
|
||||||
| POST | `/api/v1/agents/cloud` | JWT | `CloudAgentConfigCreate` | `CloudAgentConfigResponse` |
|
|
||||||
| PUT | `/api/v1/agents/cloud/{id}` | JWT | `CloudAgentConfigUpdate` | `CloudAgentConfigResponse` |
|
|
||||||
| DELETE | `/api/v1/agents/cloud/{id}` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/agents/runs` | JWT | `?agent_id&page&limit` | `AgentRunLogResponse[]` |
|
|
||||||
| POST | `/api/v1/agents/{id}/run` | JWT | — | `{ok: true, run_id}` |
|
|
||||||
| POST | `/api/v1/agents/journey/start` | JWT | `{agent_type, data_types}` | `{session_id, message, done}` |
|
|
||||||
| POST | `/api/v1/agents/journey/message` | JWT | `{session_id, message}` | `{session_id, message, done, prompt_template?}` |
|
|
||||||
| GET | `/api/v1/oauth/{provider}/authorize` | JWT | — | `{authorization_url}` |
|
|
||||||
| GET | `/api/v1/oauth/{provider}/callback` | — | OAuth code | `{encrypted_token}` |
|
|
||||||
| WS | `/api/v1/ws/device` | JWT | `device_hello` (first frame) | Agent trigger + tool_call frames |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stack
|
|
||||||
|
|
||||||
| Layer | Technology |
|
|
||||||
|-------|-----------|
|
|
||||||
| Framework | FastAPI + Uvicorn |
|
|
||||||
| LLM | LangChain + langchain-openai |
|
|
||||||
| Auth | PyJWT + bcrypt + OAuth2 |
|
|
||||||
| Billing | stripe-python + Stripe Connect |
|
|
||||||
| Blob storage | boto3 (S3) |
|
|
||||||
| Vector store | Pinecone or Qdrant (configurable) |
|
|
||||||
| Database | PostgreSQL + SQLAlchemy + Alembic |
|
|
||||||
| Rate limiting | slowapi |
|
|
||||||
| Cloud integrations | google-api-python-client, msgraph-sdk, msal |
|
|
||||||
| Agent scheduling | APScheduler |
|
|
||||||
| Testing | pytest + pytest-asyncio + httpx + moto (S3 mock) |
|
|
||||||
| Deployment | Docker → fly.io / Railway / AWS ECS |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 3 — New Files
|
|
||||||
|
|
||||||
| File | Purpose |
|
|
||||||
|---|---|
|
|
||||||
| `app/models.py` | Add `LocalAgentConfig`, `CloudAgentConfig`, `AgentRunLog` models |
|
|
||||||
| `app/schemas.py` | Add agent config schemas + WS agent frame types |
|
|
||||||
| `app/api/routes/agents.py` | Agent CRUD endpoints (catalog, local, cloud, runs, manual trigger) |
|
|
||||||
| `app/api/routes/agent_setup.py` | Chatbot Journey endpoints (start + message) |
|
|
||||||
| `app/api/routes/device_ws.py` | Persistent device WS endpoint (`/api/v1/ws/device`) |
|
|
||||||
| `app/api/routes/oauth.py` | OAuth authorize/callback for Gmail, Teams, Outlook |
|
|
||||||
| `app/core/agent_runner.py` | Agent run orchestration — local (WS file request) + cloud (API fetch) |
|
|
||||||
| `app/core/device_manager.py` | `DeviceConnectionManager` — tracks active Electron WS connections |
|
|
||||||
| `app/core/agent_scheduler.py` | Periodic scheduler for agent cron triggers |
|
|
||||||
| `app/integrations/gmail.py` | Gmail API client (fetch messages with filters) |
|
|
||||||
| `app/integrations/ms_graph.py` | MS Graph client for Outlook emails + Teams messages |
|
|
||||||
| `app/integrations/__init__.py` | Provider factory |
|
|
||||||
|
|
||||||
> **Full Phase 3 step-by-step plan:** See `AI_REFACTOR_PLAN.md` Phase 3 section.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Development Rules
|
|
||||||
|
|
||||||
1. **NEVER persist user data in plaintext.** The DB stores only auth, billing, storage metadata, and marketplace data. User context arrives in requests and is discarded. Cloud blobs are E2E encrypted client-side — backend only stores opaque bytes.
|
|
||||||
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. In plan mode, `prompt_template` fields are reference IDs only.
|
|
||||||
3. **NEVER decrypt user blobs.** `app/storage/encryption.py` only verifies checksums. No decryption key ever reaches the backend.
|
|
||||||
4. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
|
|
||||||
5. **Type hints everywhere.** All functions have full type annotations.
|
|
||||||
6. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
|
|
||||||
7. **Structured logging.** JSON logs with request ID correlation.
|
|
||||||
8. **Tier gates are enforced server-side.** Never trust client-reported tier. Always fetch from DB via `TierManager.get_tier(user_id)`.
|
|
||||||
9. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.
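
Rule 3 says the backend only verifies checksums and never decrypts blobs. A minimal sketch of such a check follows, assuming a SHA-256 hex digest supplied by the client and a constant-time comparison. The function name `verify_checksum` is hypothetical and is not taken from `app/storage/encryption.py`.

```python
import hashlib
import hmac


def verify_checksum(blob: bytes, expected_hex: str) -> bool:
    """Check an opaque encrypted blob against a client-supplied SHA-256 digest.

    The blob is hashed as-is and never decrypted. hmac.compare_digest
    compares in constant time, so the time taken does not reveal how
    many leading characters matched.
    """
    actual_hex = hashlib.sha256(blob).hexdigest()
    # Normalize case so an upper-case digest from the client still matches.
    return hmac.compare_digest(actual_hex, expected_hex.lower())
```

`hmac.compare_digest` accepts two ASCII strings or two bytes-like objects; a digest containing non-ASCII characters raises `TypeError`, which a caller would likely want to treat as a rejected upload.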
|
|
||||||
10
README.md
@@ -83,7 +83,7 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
|
|||||||
## Key Features
|
## Key Features
|
||||||
|
|
||||||
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
|
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
|
||||||
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Checkpoints (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
|
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
|
||||||
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
|
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
|
||||||
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
|
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
|
||||||
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
|
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
|
||||||
@@ -449,7 +449,7 @@ The agent system uses a registry pattern with LangChain tool-calling agents powe
|
|||||||
|---|---|---|---|
|
|---|---|---|---|
|
||||||
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
|
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
|
||||||
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
|
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
|
||||||
| **CheckpointAgent** | `checkpoint_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_checkpoints`, `create_checkpoint`, `update_checkpoint`, `delete_checkpoint` |
|
| **TimelineAgent** | `timeline_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_timelines`, `create_timeline`, `update_timeline`, `delete_timeline` |
|
||||||
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
|
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
|
||||||
|
|
||||||
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
|
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
|
||||||
@@ -504,7 +504,7 @@ Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
|
|||||||
|
|
||||||
### Built-in Templates (6)
|
### Built-in Templates (6)
|
||||||
|
|
||||||
`tpl_task_agent_default`, `tpl_checkpoint_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
|
`tpl_task_agent_default`, `tpl_timeline_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
|
||||||
|
|
||||||
### Built-in Playbooks (2)
|
### Built-in Playbooks (2)
|
||||||
|
|
||||||
@@ -643,7 +643,7 @@ Source: `app/marketplace/`
|
|||||||
- Plugin ID must match `^[a-z0-9-]+$`
|
- Plugin ID must match `^[a-z0-9-]+$`
|
||||||
- Permissions must be from the allowed set only
|
- Permissions must be from the allowed set only
|
||||||
- No binary blobs in the manifest
|
- No binary blobs in the manifest
|
||||||
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:checkpoints`, `write:checkpoints`, `read:calendar`, `write:calendar`
|
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:timelines`, `write:timelines`, `read:calendar`, `write:calendar`
|
||||||
- `get_pending(db)` — Lists plugins awaiting review.
|
- `get_pending(db)` — Lists plugins awaiting review.
|
||||||
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
|
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
|
||||||
|
|
||||||
@@ -734,7 +734,7 @@ adiuva-api/
|
|||||||
│ ├── agents/ # LLM-powered domain agents
|
│ ├── agents/ # LLM-powered domain agents
|
||||||
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
|
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
|
||||||
│ │ ├── project_agent.py # Project lifecycle (6 tools)
|
│ │ ├── project_agent.py # Project lifecycle (6 tools)
|
||||||
│ │ ├── checkpoint_agent.py # Milestones (4 tools)
|
│ │ ├── timeline_agent.py # Milestones (4 tools)
|
||||||
│ │ └── note_agent.py # Markdown notes (5 tools)
|
│ │ └── note_agent.py # Markdown notes (5 tools)
|
||||||
│ │
|
│ │
|
||||||
│ ├── core/ # Orchestration engine
|
│ ├── core/ # Orchestration engine
|
||||||
|
|||||||
@@ -37,12 +37,12 @@ _SEED_PLUGINS = [
|
|||||||
{
|
{
|
||||||
"id": "plugin-slack-notify",
|
"id": "plugin-slack-notify",
|
||||||
"name": "Slack Notifier",
|
"name": "Slack Notifier",
|
||||||
"description": "Post task and checkpoint updates to Slack channels.",
|
"description": "Post task and timeline updates to Slack channels.",
|
||||||
"version": "1.2.0",
|
"version": "1.2.0",
|
||||||
"author_name": "Adiuva",
|
"author_name": "Adiuva",
|
||||||
"category": "communication",
|
"category": "communication",
|
||||||
"price_cents": 499,
|
"price_cents": 499,
|
||||||
"permissions": json.dumps(["read:tasks", "read:checkpoints"]),
|
"permissions": json.dumps(["read:tasks", "read:timelines"]),
|
||||||
"status": "approved",
|
"status": "approved",
|
||||||
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
"install_count": 0,
|
"install_count": 0,
|
||||||
|
|||||||
144
alembic/versions/004_add_memory_tables.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""Add memory tables and user encryption_key column.
|
||||||
|
|
||||||
|
Memory tables:
|
||||||
|
memory_core — per-user key/value preferences (encrypted)
|
||||||
|
memory_associative — semantic memory with pgvector embedding (encrypted)
|
||||||
|
memory_episodic — session summaries (encrypted)
|
||||||
|
memory_proactive — behavioral patterns (encrypted)
|
||||||
|
|
||||||
|
Also adds encryption_key column to users table.
|
||||||
|
|
||||||
|
Revision ID: 004
|
||||||
|
Revises: 003
|
||||||
|
Create Date: 2026-03-08
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "004"
|
||||||
|
down_revision: Union[str, None] = "003"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Enable pgvector extension (idempotent) ────────────────────────────────
|
||||||
|
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||||
|
|
||||||
|
# ── Add encryption_key to users ───────────────────────────────────────────
|
||||||
|
op.add_column(
|
||||||
|
"users",
|
||||||
|
sa.Column("encryption_key", sa.String(64), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── memory_core ───────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"memory_core",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("key", sa.String(255), nullable=False),
|
||||||
|
sa.Column("value_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_core_user_id", "memory_core", ["user_id"])
|
||||||
|
|
||||||
|
# ── memory_associative ────────────────────────────────────────────────────
|
||||||
|
# The embedding column uses pgvector's vector(1536) type.
|
||||||
|
op.create_table(
|
||||||
|
"memory_associative",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("content_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column("entity_type", sa.String(100), nullable=True),
|
||||||
|
sa.Column("entity_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# Add the pgvector column separately (not supported by generic sa types)
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE memory_associative ADD COLUMN embedding vector(1536);"
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_associative_user_id", "memory_associative", ["user_id"])
|
||||||
|
# IVFFlat index for approximate nearest-neighbour search
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX ix_memory_associative_embedding "
|
||||||
|
"ON memory_associative USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── memory_episodic ───────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"memory_episodic",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("summary_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column("session_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_episodic_user_id", "memory_episodic", ["user_id"])
|
||||||
|
op.create_index("ix_memory_episodic_session_id", "memory_episodic", ["session_id"])
|
||||||
|
|
||||||
|
# ── memory_proactive ──────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"memory_proactive",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("pattern_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column("confidence", sa.Float, nullable=False, server_default="0.5"),
|
||||||
|
sa.Column("source", sa.String(50), nullable=False, server_default="inferred"),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_proactive_user_id", "memory_proactive", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("memory_proactive")
|
||||||
|
op.drop_table("memory_episodic")
|
||||||
|
op.drop_index("ix_memory_associative_embedding", "memory_associative")
|
||||||
|
op.drop_table("memory_associative")
|
||||||
|
op.drop_table("memory_core")
|
||||||
|
op.drop_column("users", "encryption_key")
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
"""add name and surname to users table
|
||||||
|
|
||||||
|
Revision ID: 818478c251dc
|
||||||
|
Revises: 004
|
||||||
|
Create Date: 2026-03-10 15:10:42.811947
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '818478c251dc'
|
||||||
|
down_revision: Union[str, None] = '004'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column('users', sa.Column('name', sa.String(length=100), nullable=True))
|
||||||
|
op.add_column('users', sa.Column('surname', sa.String(length=100), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column('users', 'surname')
|
||||||
|
op.drop_column('users', 'name')
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Import all agent modules to trigger @registry.register decorators."""
|
"""Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
|
||||||
|
|
||||||
from app.agents import checkpoint_agent, note_agent, project_agent, task_agent
|
from app.agents import timeline_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
__all__ = ["checkpoint_agent", "note_agent", "project_agent", "task_agent"]
|
__all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
|
|||||||
@@ -1,127 +0,0 @@
|
|||||||
"""Checkpoint agent — project milestone management (list, create, update, delete)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
|
||||||
"You are a project checkpoint assistant. Checkpoints are milestone dates that\n"
|
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a checkpoint, 0 otherwise\n"
|
|
||||||
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
|
||||||
" - For update_checkpoint, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Listing without a project_id returns all checkpoints across projects\n"
|
|
||||||
" - Always echo the title and formatted date in your confirmation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_checkpoints(project_id: str = "") -> str:
|
|
||||||
"""List checkpoints. Provide project_id to scope to a specific project."""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="select",
|
|
||||||
table="checkpoints",
|
|
||||||
filters={"projectId": project_id or None},
|
|
||||||
)
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return "No checkpoints found."
|
|
||||||
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
|
||||||
return f"Found {len(rows)} checkpoint(s):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def create_checkpoint(
|
|
||||||
project_id: str,
|
|
||||||
title: str,
|
|
||||||
date: int,
|
|
||||||
is_ai_suggested: int = 0,
|
|
||||||
is_approved: int = 0,
|
|
||||||
) -> str:
|
|
||||||
"""Create a project checkpoint (milestone).
|
|
||||||
project_id: REQUIRED UUID of the parent project
|
|
||||||
title: descriptive name for the milestone
|
|
||||||
date: Unix timestamp in milliseconds
|
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
|
||||||
is_approved: 0 until the user confirms
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="insert",
|
|
||||||
table="checkpoints",
|
|
||||||
data={
|
|
||||||
"projectId": project_id,
|
|
||||||
"title": title,
|
|
||||||
"date": date,
|
|
||||||
"isAiSuggested": is_ai_suggested,
|
|
||||||
"isApproved": is_approved,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
row = result["row"]
|
|
||||||
return f"Checkpoint created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def update_checkpoint(
|
|
||||||
checkpoint_id: str,
|
|
||||||
title: str = "",
|
|
||||||
date: int = -1,
|
|
||||||
is_approved: int = -1,
|
|
||||||
) -> str:
|
|
||||||
"""Update a checkpoint. Only pass fields that should change.
|
|
||||||
checkpoint_id: UUID of the checkpoint (required)
|
|
||||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
|
||||||
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
|
||||||
"""
|
|
||||||
updates: dict[str, Any] = {}
|
|
||||||
if title:
|
|
||||||
updates["title"] = title
|
|
||||||
if date != -1:
|
|
||||||
updates["date"] = date
|
|
||||||
if is_approved != -1:
|
|
||||||
updates["isApproved"] = is_approved
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="update",
|
|
||||||
table="checkpoints",
|
|
||||||
data={"id": checkpoint_id, "updates": updates},
|
|
||||||
)
|
|
||||||
row = result["row"]
|
|
||||||
return f"Checkpoint updated: '{row['title']}' (id: {row['id']})"
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def delete_checkpoint(checkpoint_id: str) -> str:
|
|
||||||
"""Delete a checkpoint permanently by its UUID."""
|
|
||||||
await execute_on_client(action="delete", table="checkpoints", data={"id": checkpoint_id})
|
|
||||||
return f"Checkpoint {checkpoint_id} deleted."
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class CheckpointAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "checkpoint_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages project checkpoints (milestones): list, create, update, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [list_checkpoints, create_checkpoint, update_checkpoint, delete_checkpoint]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
@@ -1,30 +1,14 @@
|
|||||||
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
"""Note agent — tool definitions for Markdown note CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
from app.core.llm import embed
|
||||||
from app.core.llm import embed, get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - content is always Markdown; preserve formatting when updating\n"
|
|
||||||
" - project_id is optional; link a note to a project when mentioned\n"
|
|
||||||
" - When updating, call get_note first if you need to read existing content\n"
|
|
||||||
" before appending or replacing sections\n"
|
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
|
||||||
" when the user is working within a specific project\n"
|
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
|
||||||
" is already in the note (retrieved via get_note)."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
@@ -121,23 +105,4 @@ async def delete_note(note_id: str) -> str:
|
|||||||
return f"Note {note_id} deleted."
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class NoteAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "note_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages notes: list, get, create, update, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [list_notes, get_note, create_note, update_note, delete_note]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -1,32 +1,13 @@
|
|||||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
"""Project agent — tool definitions for project lifecycle CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
|
||||||
"You are a project management assistant. You help users create, find,\n"
|
|
||||||
"update, and archive projects in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: active, archived\n"
|
|
||||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
|
||||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
|
||||||
" derive it from context data — do not fabricate content\n"
|
|
||||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
|
||||||
" user wants a complete cross-client view including archived projects\n"
|
|
||||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
|
||||||
" list_projects if you only have a project name\n"
|
|
||||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
|
||||||
" only call delete_project when the user explicitly confirms deletion."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_projects(
|
async def list_projects(
|
||||||
@@ -136,30 +117,4 @@ async def delete_project(project_id: str) -> str:
|
|||||||
return f"Project {project_id} permanently deleted."
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class ProjectAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "project_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [
|
|
||||||
list_projects,
|
|
||||||
list_all_projects,
|
|
||||||
get_project,
|
|
||||||
create_project,
|
|
||||||
update_project,
|
|
||||||
delete_project,
|
|
||||||
]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -1,35 +1,14 @@
|
|||||||
"""Task agent — full CRUD for tasks and task comments."""
|
"""Task agent — tool definitions for task and task comment CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
|
||||||
"You are a task management assistant for a project workspace.\n"
|
|
||||||
"You create, update, list, and track tasks and their comments.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: todo, in_progress, done\n"
|
|
||||||
" - priority must be one of: high, medium, low\n"
|
|
||||||
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
|
||||||
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
|
||||||
" - project_id is optional; link to a project when the user mentions one\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
|
||||||
" did not explicitly request; 0 otherwise\n"
|
|
||||||
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
|
|
||||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
|
||||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Always confirm the action in plain, user-friendly language."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task tools ────────────────────────────────────────────────────────
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -220,35 +199,4 @@ async def delete_task_comment(comment_id: str) -> str:
|
|||||||
return f"Comment {comment_id} deleted."
|
return f"Comment {comment_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
# ── Agent ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class TaskAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "task_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [
|
|
||||||
list_tasks,
|
|
||||||
create_task,
|
|
||||||
update_task,
|
|
||||||
delete_task,
|
|
||||||
list_tasks_due_today,
|
|
||||||
list_task_comments,
|
|
||||||
add_task_comment,
|
|
||||||
delete_task_comment,
|
|
||||||
]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
92
app/agents/timeline_agent.py
Normal file
@@ -0,0 +1,92 @@
"""Timeline agent — tool definitions for project milestone CRUD."""

from __future__ import annotations

from typing import Any

from langchain_core.tools import tool

from app.core.ws_context import execute_on_client


@tool
async def list_timelines(project_id: str = "") -> str:
    """List timelines. Provide project_id to scope to a specific project."""
    result = await execute_on_client(
        action="select",
        table="timelines",
        filters={"projectId": project_id or None},
    )
    rows = result.get("rows", [])
    if not rows:
        return "No timelines found."
    lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
    return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)


@tool
async def create_timeline(
    project_id: str,
    title: str,
    date: int,
    is_ai_suggested: int = 0,
    is_approved: int = 0,
) -> str:
    """Create a project timeline (milestone).
    project_id: REQUIRED UUID of the parent project
    title: descriptive name for the milestone
    date: Unix timestamp in milliseconds
    is_ai_suggested: 1 if proactively suggested, 0 if user-requested
    is_approved: 0 until the user confirms
    """
    result = await execute_on_client(
        action="insert",
        table="timelines",
        data={
            "projectId": project_id,
            "title": title,
            "date": date,
            "isAiSuggested": is_ai_suggested,
            "isApproved": is_approved,
        },
    )
    row = result["row"]
    return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"


@tool
async def update_timeline(
    timeline_id: str,
    title: str = "",
    date: int = -1,
    is_approved: int = -1,
) -> str:
    """Update a timeline. Only pass fields that should change.
    timeline_id: UUID of the timeline (required)
    date: -1 means unchanged; any other value sets the new date (ms timestamp)
    is_approved: -1 means unchanged; 0 or 1 sets the approval state
    """
    updates: dict[str, Any] = {}
    if title:
        updates["title"] = title
    if date != -1:
        updates["date"] = date
    if is_approved != -1:
        updates["isApproved"] = is_approved
    result = await execute_on_client(
        action="update",
        table="timelines",
        data={"id": timeline_id, "updates": updates},
    )
    row = result["row"]
    return f"Timeline updated: '{row['title']}' (id: {row['id']})"


@tool
async def delete_timeline(timeline_id: str) -> str:
    """Delete a timeline permanently by its UUID."""
    await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
    return f"Timeline {timeline_id} deleted."
|
|
||||||
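The `update_timeline` tool in `app/agents/timeline_agent.py` treats an empty `title` and a `-1` value for `date` or `is_approved` as "leave unchanged", and forwards only the remaining fields. A minimal standalone sketch of that sentinel pattern follows. `build_updates` is a hypothetical helper name used here for illustration; it is not part of the repository and only mirrors the field names that appear in the diff.

```python
from __future__ import annotations

from typing import Any


def build_updates(title: str = "", date: int = -1, is_approved: int = -1) -> dict[str, Any]:
    """Hypothetical helper: collect only the fields that differ from their sentinel."""
    updates: dict[str, Any] = {}
    if title:  # an empty string means "keep the current title"
        updates["title"] = title
    if date != -1:  # -1 means "keep the current date"
        updates["date"] = date
    if is_approved != -1:  # -1 means "keep the current approval state"
        updates["isApproved"] = is_approved
    return updates


print(build_updates(title="Beta launch"))
print(build_updates(date=1767225600000, is_approved=1))
print(build_updates())
```

Two consequences of this design seem worth noting: a title cannot be cleared to an empty string through such a tool, and a call with no changed fields produces an empty update set, which a caller could choose to detect before making the round trip to the client.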
@@ -55,11 +55,23 @@ async def get_current_user(
|
|||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
|
|
||||||
# Live tier lookup — subscription row is the authoritative source.
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
from app.models import Subscription # noqa: PLC0415
|
from app.models import Subscription, User # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str = result.scalar_one_or_none() or "free"
|
tier: str = result.scalar_one_or_none() or "free"
|
||||||
|
|
||||||
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
# Fetch name/surname from user row.
|
||||||
|
user_result = await db.execute(
|
||||||
|
select(User.name, User.surname).where(User.id == user_id)
|
||||||
|
)
|
||||||
|
user_row = user_result.one_or_none()
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user_id,
|
||||||
|
email=email,
|
||||||
|
name=user_row.name if user_row else None,
|
||||||
|
surname=user_row.surname if user_row else None,
|
||||||
|
tier=tier,
|
||||||
|
) # type: ignore[arg-type]
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ and produce a detailed prompt_template that a separate AI will use as its instru
|
|||||||
|
|
||||||
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
|
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||||
1. The type and format of the source content.
|
1. The type and format of the source content.
|
||||||
2. Which data types to extract: tasks, notes, checkpoints, and/or projects.
|
2. Which data types to extract: tasks, notes, timelines, and/or projects.
|
||||||
3. How fields should be mapped (e.g. email subject → task title).
|
3. How fields should be mapped (e.g. email subject → task title).
|
||||||
4. Priority or status rules (e.g. "urgent" keyword → high priority).
|
4. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||||
5. Any special handling, date extraction, or exclusions.
|
5. Any special handling, date extraction, or exclusions.
|
||||||
@@ -121,7 +121,7 @@ these exact markers on their own lines:
|
|||||||
|
|
||||||
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
|
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
|
||||||
and must return a JSON array of records in this shape:
|
and must return a JSON array of records in this shape:
|
||||||
[{{ "table": "<tasks|notes|checkpoints|projects>", "data": {{ <field: value> }} }}, ...]
|
[{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
|
||||||
|
|
||||||
Rules for the generated template:
|
Rules for the generated template:
|
||||||
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
|
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
|
||||||
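The prompt text in this hunk asks the downstream model for a JSON array of `{"table": ..., "data": {...}}` records, with `timelines` replacing `checkpoints` as an allowed table. As a rough sketch of how a consumer might check that shape, the snippet below filters model output against the four table names from the prompt. `parse_records` and `ALLOWED_TABLES` are hypothetical names introduced for illustration; the diff does not show how the backend actually validates this output.

```python
from __future__ import annotations

import json
from typing import Any

# Table names as listed in the updated prompt text.
ALLOWED_TABLES = {"tasks", "notes", "timelines", "projects"}


def parse_records(raw: str) -> list[dict[str, Any]]:
    """Hypothetical validator: parse model output and keep only well-formed records."""
    parsed = json.loads(raw)
    if not isinstance(parsed, list):
        raise ValueError("expected a JSON array of records")
    records: list[dict[str, Any]] = []
    for item in parsed:
        if not isinstance(item, dict):
            continue
        table = item.get("table")
        data = item.get("data")
        # Keep a record only if the table is known and data is an object.
        if isinstance(table, str) and table in ALLOWED_TABLES and isinstance(data, dict):
            records.append({"table": table, "data": data})
    return records


sample = json.dumps([
    {"table": "tasks", "data": {"title": "Send invoice", "priority": "high"}},
    {"table": "timelines", "data": {"title": "Kickoff", "date": 1767225600000}},
    {"table": "checkpoints", "data": {"title": "Uses the old table name"}},
])
print(parse_records(sample))
```

In this sketch a record that still uses the old `checkpoints` name is dropped silently; logging or rejecting the whole batch would be equally reasonable choices, depending on how strict the importer should be.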
|
|||||||
@@ -13,6 +13,7 @@ import uuid
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -65,6 +66,8 @@ def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
|||||||
class _RegisterRequest(BaseModel):
|
class _RegisterRequest(BaseModel):
|
||||||
email: str
|
email: str
|
||||||
password: str
|
password: str
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class _LoginRequest(BaseModel):
|
class _LoginRequest(BaseModel):
|
||||||
@@ -92,8 +95,11 @@ async def register(
|
|||||||
user = User(
|
user = User(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
email=body.email,
|
email=body.email,
|
||||||
|
name=body.name,
|
||||||
|
surname=body.surname,
|
||||||
password_hash=_hash_password(body.password),
|
password_hash=_hash_password(body.password),
|
||||||
tier="free",
|
tier="free",
|
||||||
|
encryption_key=Fernet.generate_key().decode(),
|
||||||
)
|
)
|
||||||
db.add(user)
|
db.add(user)
|
||||||
await db.flush() # get user.id without committing
|
await db.flush() # get user.id without committing
|
||||||
@@ -191,7 +197,39 @@ async def refresh(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateProfileRequest(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserProfile)
|
@router.get("/me", response_model=UserProfile)
|
||||||
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
||||||
"""Return the profile for the authenticated user."""
|
"""Return the profile for the authenticated user."""
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me", response_model=UserProfile)
|
||||||
|
async def update_profile(
|
||||||
|
body: _UpdateProfileRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Update the authenticated user's name and surname."""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
if body.name is not None:
|
||||||
|
user.name = body.name
|
||||||
|
if body.surname is not None:
|
||||||
|
user.surname = body.surname
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user.id,
|
||||||
|
email=user.email,
|
||||||
|
name=user.name,
|
||||||
|
surname=user.surname,
|
||||||
|
tier=current_user.tier,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,78 +1,42 @@
|
|||||||
"""Chat routes: POST /chat and WebSocket /chat/stream."""
|
"""Chat routes: POST /chat (REST fallback).
|
||||||
|
|
||||||
|
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
from fastapi import APIRouter, Depends
|
||||||
import json
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from jose import JWTError, jwt
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.config.settings import settings
|
from app.core.deep_agent import run_home
|
||||||
from app.core.orchestrator import orchestrate, orchestrate_stream
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.schemas import ChatRequest, UserProfile
|
from app.db import async_session
|
||||||
|
from app.schemas import ChatRequest, ChatResponse, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def chat(
|
async def chat(
|
||||||
body: ChatRequest,
|
body: ChatRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""Route a chat message through the orchestrator.
|
"""Route a chat message through the Home deep agent (non-streaming)."""
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(current_user.id, body.message)
|
||||||
|
|
||||||
Returns ``ChatResponse`` for ``execution_mode='direct'``,
|
context = {
|
||||||
or ``ExecutionPlan`` for ``execution_mode='plan'``.
|
**body.context.model_dump(),
|
||||||
"""
|
**memory_context,
|
||||||
result = await orchestrate(body)
|
}
|
||||||
|
|
||||||
|
response_text = await run_home(
|
||||||
|
user_id=current_user.id,
|
||||||
|
message=body.message,
|
||||||
|
context=context,
|
||||||
|
db_session_factory=async_session,
|
||||||
|
)
|
||||||
|
result = ChatResponse(response=response_text)
|
||||||
return JSONResponse(content=result.model_dump())
|
return JSONResponse(content=result.model_dump())
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/stream")
|
|
||||||
async def chat_stream(websocket: WebSocket) -> None:
|
|
||||||
"""Streaming chat via WebSocket.
|
|
||||||
|
|
||||||
Auth: ``?token=<jwt>`` query param (Bearer not possible during WS handshake).
|
|
||||||
|
|
||||||
Protocol:
|
|
||||||
1. Client sends ``ChatRequest`` as the first JSON text frame.
|
|
||||||
2. Server streams response text chunks.
|
|
||||||
3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``.
|
|
||||||
4. Server pings every 30 s to keep the connection alive.
|
|
||||||
"""
|
|
||||||
# Authenticate before accepting the connection
|
|
||||||
token = websocket.query_params.get("token", "")
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
|
|
||||||
user_id: str | None = payload.get("sub")
|
|
||||||
if not user_id:
|
|
||||||
raise JWTError("missing sub")
|
|
||||||
except JWTError:
|
|
||||||
await websocket.close(code=1008) # 1008 = Policy Violation
|
|
||||||
return
|
|
||||||
|
|
||||||
await websocket.accept()
|
|
||||||
|
|
||||||
try:
|
|
||||||
raw = await websocket.receive_text()
|
|
||||||
body = ChatRequest.model_validate_json(raw)
|
|
||||||
|
|
||||||
async def _heartbeat() -> None:
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
|
||||||
await websocket.send_text(json.dumps({"ping": True}))
|
|
||||||
|
|
||||||
heartbeat_task = asyncio.create_task(_heartbeat())
|
|
||||||
try:
|
|
||||||
async for chunk in orchestrate_stream(body):
|
|
||||||
await websocket.send_text(chunk)
|
|
||||||
finally:
|
|
||||||
heartbeat_task.cancel()
|
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
pass
|
|
||||||
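The new `POST /chat` handler builds its context by unpacking the client-supplied context first and the memory context second. In a Python dict display, later keys win, so any key present in both ends up with the memory value. The short sketch below illustrates that precedence; the key names are invented for illustration, since the diff does not show what `enrich_context` returns.

```python
# Illustrative keys only: the real context fields are not shown in the diff.
client_context = {"conversation_history": ["hi"], "user_preferences": "from client"}
memory_context = {"user_preferences": "from memory", "relevant_episodes": []}

# Same merge order as the handler: client context first, memory context second.
context = {**client_context, **memory_context}

print(context["user_preferences"])
print(sorted(context))
```

If client-supplied values were meant to take priority instead, swapping the order of the two unpackings would be enough.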
|
|||||||
@@ -33,14 +33,19 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from sqlalchemy import select, update
|
from sqlalchemy import update
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.core.deep_agent import run_home_stream, run_floating_stream
|
||||||
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||||
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
from app.schemas import WsFrameType
|
from app.schemas import WsFrameType
|
||||||
@@ -173,6 +178,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
"device_ws: agent_complete missing run_id from user=%s", user_id
|
"device_ws: agent_complete missing run_id from user=%s", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.home_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_home_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.floating_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -183,6 +198,135 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_WS_TOOL_CALL_TIMEOUT = 30 # seconds to wait for Electron tool_result
|
||||||
|
|
||||||
|
|
||||||
|
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||||
|
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||||
|
async def _executor(payload: dict) -> dict:
|
||||||
|
payload["type"] = WsFrameType.tool_call
|
||||||
|
call_id = payload["id"]
|
||||||
|
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
|
||||||
|
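# Register the pending future before sending, so a fast tool_result cannot arrive unmatched.
|
||||||
|
future = device_manager.create_pending_call(user_id, call_id)
|
||||||
|
await websocket.send_text(json.dumps(payload))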
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(future, timeout=_WS_TOOL_CALL_TIMEOUT)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
"ws_executor: timeout waiting for tool_result id=%s action=%s user=%s",
|
||||||
|
call_id, payload.get("action"), user_id,
|
||||||
|
)
|
||||||
|
# Clean up the pending future so it doesn't leak
|
||||||
|
conn = device_manager._connections.get(user_id)
|
||||||
|
if conn:
|
||||||
|
conn.pending_calls.pop(call_id, None)
|
||||||
|
return {"error": f"Tool call timed out after {_WS_TOOL_CALL_TIMEOUT}s", "rows": []}
|
||||||
|
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
|
||||||
|
call_id, type(result).__name__,
|
||||||
|
list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||||
|
if result is None:
|
||||||
|
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
|
||||||
|
return result
|
||||||
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_home_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
|
||||||
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
try:
|
||||||
|
event_stream = run_home_stream(
|
||||||
|
user_id, message, context, db_session_factory=async_session
|
||||||
|
)
|
||||||
|
formatter = HomeFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: home_request failed user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
# ── Memory: store episode after response ──────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.store_episode(
|
||||||
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_floating_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
scope: dict = frame.get("scope", {})
|
||||||
|
|
||||||
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
|
|
||||||
|
context: dict = {"scope": scope, **memory_context}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
try:
|
||||||
|
event_stream = run_floating_stream(
|
||||||
|
user_id, message, context, scope=scope,
|
||||||
|
db_session_factory=async_session,
|
||||||
|
)
|
||||||
|
formatter = FloatingFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: floating_request failed user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
# ── Memory: store episode after response ──────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.store_episode(
|
||||||
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||||
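The executor added in this file correlates each outgoing `tool_call` frame with its `tool_result` by call id, waits on a future with a timeout, and falls back to an error payload when the client does not answer. The following is a self-contained sketch of that request/response pattern using only `asyncio`. `PendingCalls` and `call_client` are hypothetical stand-ins: the internals of `device_manager.create_pending_call` are not shown in the diff, so this is an assumption about the general shape, not a copy of the real implementation.

```python
from __future__ import annotations

import asyncio
from typing import Any, Awaitable, Callable


class PendingCalls:
    """Hypothetical registry mapping call ids to futures awaiting a client reply."""

    def __init__(self) -> None:
        self._futures: dict[str, asyncio.Future[Any]] = {}

    def create(self, call_id: str) -> asyncio.Future[Any]:
        future = asyncio.get_running_loop().create_future()
        self._futures[call_id] = future
        return future

    def resolve(self, call_id: str, result: Any) -> None:
        future = self._futures.pop(call_id, None)
        if future is not None and not future.done():
            future.set_result(result)

    def discard(self, call_id: str) -> None:
        self._futures.pop(call_id, None)


async def call_client(
    pending: PendingCalls,
    send: Callable[[dict], Awaitable[None]],
    call_id: str,
    payload: dict,
    timeout: float,
) -> dict:
    # Register the future first so a reply can never arrive before it exists.
    future = pending.create(call_id)
    await send({"id": call_id, **payload})
    try:
        return await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        pending.discard(call_id)  # avoid leaking the abandoned entry
        return {"error": f"Tool call timed out after {timeout}s", "rows": []}


async def main() -> tuple[dict, dict]:
    pending = PendingCalls()

    async def replying_send(frame: dict) -> None:
        # Simulates a client that answers on the next event-loop iteration.
        asyncio.get_running_loop().call_soon(
            pending.resolve, frame["id"], {"rows": [{"id": "t1"}]}
        )

    async def silent_send(frame: dict) -> None:
        return None  # simulates a client that never answers

    ok = await call_client(pending, replying_send, "c1", {"action": "select"}, 1.0)
    timed_out = await call_client(pending, silent_send, "c2", {"action": "select"}, 0.05)
    return ok, timed_out


ok, timed_out = asyncio.run(main())
print(ok)
print(timed_out)
```

Returning an error dictionary with an empty `rows` list on timeout, as the diff does, lets list-style tools degrade to "nothing found" rather than raising, although tools that index `result["row"]` directly would still need their own handling.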
|
|||||||
@@ -1,37 +0,0 @@
|
|||||||
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.core.execution_plan import plan_cache
|
|
||||||
from app.schemas import ExecutionPlan, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/plans", tags=["plans"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/playbook", response_model=list[ExecutionPlan])
|
|
||||||
async def list_playbooks(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> list[ExecutionPlan]:
|
|
||||||
"""Return all cached execution plan playbooks for the authenticated user.
|
|
||||||
|
|
||||||
TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
|
|
||||||
"""
|
|
||||||
return plan_cache.get_all_playbooks()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
|
|
||||||
async def get_playbook(
|
|
||||||
plan_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> ExecutionPlan:
|
|
||||||
"""Return a specific execution plan playbook by ID."""
|
|
||||||
plan = plan_cache.get_plan(plan_id)
|
|
||||||
if plan is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Plan not found: {plan_id}",
|
|
||||||
)
|
|
||||||
return plan
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
@@ -26,17 +26,35 @@ class Settings(BaseSettings):
|
|||||||
OPENAI_API_KEY: str = ""
|
OPENAI_API_KEY: str = ""
|
||||||
ANTHROPIC_API_KEY: str = ""
|
ANTHROPIC_API_KEY: str = ""
|
||||||
GOOGLE_API_KEY: str = ""
|
GOOGLE_API_KEY: str = ""
|
||||||
|
CEREBRAS_API_KEY: str = ""
|
||||||
|
|
||||||
LLM_MODEL: str = "gpt-4o"
|
LLM_MODEL: str = "gpt-4o"
|
||||||
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
||||||
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
|
|
||||||
|
# GitHub Copilot OAuth token storage directory.
|
||||||
|
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||||
|
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
||||||
|
GITHUB_COPILOT_TOKEN_DIR: str = ""
|
||||||
|
|
||||||
|
# OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows.
|
||||||
|
GMAIL_CLIENT_ID: str = ""
|
||||||
|
GMAIL_CLIENT_SECRET: str = ""
|
||||||
|
MS_CLIENT_ID: str = ""
|
||||||
|
MS_CLIENT_SECRET: str = ""
|
||||||
|
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
||||||
|
MS_TENANT_ID: str = "common"
|
||||||
|
|
||||||
|
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||||
|
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||||
|
# Generate with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||||
|
OAUTH_ENCRYPTION_KEY: str = ""
|
||||||
|
|
||||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
class Config:
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
env_file = ".env"
|
|
||||||
env_file_encoding = "utf-8"
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -1,137 +0,0 @@
|
|||||||
"""Agent Registry — base classes and singleton registry for chat agents."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(ABC):
|
|
||||||
"""Common base for all agents."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
user_id: str = "",
|
|
||||||
shared_memory: dict[str, Any] | None = None,
|
|
||||||
vector_store_context: list[str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.user_id = user_id
|
|
||||||
self.shared_memory: dict[str, Any] = shared_memory or {}
|
|
||||||
self.vector_store_context: list[str] = vector_store_context or []
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_name(self) -> str: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_description(self) -> str: ...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def skills(self) -> list[str]:
|
|
||||||
"""Override in subclasses to advertise capabilities."""
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
class ChatAgent(BaseAgent):
|
|
||||||
"""Base class for LLM-powered chat agents."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
"""Process a user query and return a text response."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
"""Return LangChain tool definitions available to this agent."""
|
|
||||||
...
|
|
||||||
|
|
||||||
async def _tool_loop(
|
|
||||||
self,
|
|
||||||
llm: Any,
|
|
||||||
messages: list[Any],
|
|
||||||
tools: list[Any],
|
|
||||||
max_iter: int = 5,
|
|
||||||
) -> str:
|
|
||||||
"""Shared tool-calling loop.
|
|
||||||
|
|
||||||
Binds *tools* to *llm*, invokes iteratively until the model stops
|
|
||||||
requesting tool calls or *max_iter* is reached, and returns the
|
|
||||||
final text response.
|
|
||||||
"""
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
|
||||||
|
|
||||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
|
||||||
|
|
||||||
for _ in range(max_iter):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
return str(response.content)
|
|
||||||
|
|
||||||
# Execute each requested tool call
|
|
||||||
tool_map = {t.name: t for t in tools}
|
|
||||||
for call in response.tool_calls:
|
|
||||||
tool_fn = tool_map.get(call["name"])
|
|
||||||
if tool_fn is None:
|
|
||||||
result = f"Unknown tool: {call['name']}"
|
|
||||||
else:
|
|
||||||
result = await tool_fn.ainvoke(call["args"])
|
|
||||||
messages.append(
|
|
||||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Exhausted iterations — ask model for a final answer without tools
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
return str(response.content)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentRegistry:
|
|
||||||
"""Singleton registry for ChatAgent subclasses."""
|
|
||||||
|
|
||||||
_instance: AgentRegistry | None = None
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._agents: dict[str, type[ChatAgent]] = {}
|
|
||||||
|
|
||||||
def __new__(cls) -> AgentRegistry:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
cls._instance._agents = {}
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
# ── public API ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
|
|
||||||
"""Class decorator — registers an agent by its name."""
|
|
||||||
instance = agent_class()
|
|
||||||
name = instance.get_name()
|
|
||||||
self._agents[name] = agent_class
|
|
||||||
return agent_class
|
|
||||||
|
|
||||||
def get(self, name: str) -> ChatAgent:
|
|
||||||
"""Return a fresh instance of the named agent."""
|
|
||||||
cls = self._agents.get(name)
|
|
||||||
if cls is None:
|
|
||||||
raise KeyError(f"Agent not found: {name}")
|
|
||||||
return cls()
|
|
||||||
|
|
||||||
def list_agents(self) -> list[dict[str, str]]:
|
|
||||||
"""Return ``[{name, description}]`` for the orchestrator prompt."""
|
|
||||||
result: list[dict[str, str]] = []
|
|
||||||
for cls in self._agents.values():
|
|
||||||
inst = cls()
|
|
||||||
result.append(
|
|
||||||
{"name": inst.get_name(), "description": inst.get_description()}
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def call_agent(
|
|
||||||
self, name: str, query: str, context: dict[str, Any]
|
|
||||||
) -> str:
|
|
||||||
"""Instantiate the named agent and call its ``handle`` method."""
|
|
||||||
agent = self.get(name)
|
|
||||||
return await agent.handle(query, context)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
registry = AgentRegistry()
|
|
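The removed `AgentRegistry` combines a `__new__`-based singleton with an `__init__` that assigns `self._agents = {}`. Because Python runs `__init__` after every `AgentRegistry()` call, a second call would appear to clear earlier registrations. The following is a minimal sketch of the same pattern with state created only in `__new__`; `Registry`, `register(name, agent_class)` and `names()` are simplified, hypothetical stand-ins, not the project's API.

```python
from __future__ import annotations


class Registry:
    """Simplified, hypothetical stand-in for the removed AgentRegistry."""

    _instance: Registry | None = None

    def __new__(cls) -> Registry:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # State is created exactly once, here, rather than in __init__.
            cls._instance._agents = {}
        return cls._instance

    # No __init__ on purpose: Python calls __init__ after every Registry()
    # call, so an __init__ that assigns self._agents = {} would reset the
    # shared instance each time.

    def register(self, name: str, agent_class: type) -> None:
        """Store an agent class under the given name."""
        self._agents[name] = agent_class

    def names(self) -> list[str]:
        """Return the registered names in sorted order."""
        return sorted(self._agents)


first = Registry()
first.register("task_agent", object)
second = Registry()  # same object as `first`; the registration is still there
```

If callers only ever use a module-level instance, the reset never triggers, which may be why it went unnoticed.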
||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Agent run orchestrator.
|
"""Agent run manager.
|
||||||
|
|
||||||
Drives two agent types:
|
Drives two agent types:
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from croniter import croniter
|
from croniter import croniter
|
||||||
@@ -53,7 +53,7 @@ _INSERT_TIMEOUT: int = 30
|
|||||||
# ── Allowed tables & extraction schema hints ───────────────────────────────
|
# ── Allowed tables & extraction schema hints ───────────────────────────────
|
||||||
|
|
||||||
_ALLOWED_TABLES: frozenset[str] = frozenset(
|
_ALLOWED_TABLES: frozenset[str] = frozenset(
|
||||||
{"tasks", "notes", "checkpoints", "projects", "taskComments"}
|
{"tasks", "notes", "timelines", "projects", "taskComments"}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Field descriptions fed to the extraction LLM as concise schema references.
|
# Field descriptions fed to the extraction LLM as concise schema references.
|
||||||
@@ -65,7 +65,7 @@ _TABLE_SCHEMAS: dict[str, str] = {
|
|||||||
"assignee (JSON array string), dueDate (ms timestamp int), projectId (str)"
|
"assignee (JSON array string), dueDate (ms timestamp int), projectId (str)"
|
||||||
),
|
),
|
||||||
"notes": "title (str, required), content (str, markdown), projectId (str)",
|
"notes": "title (str, required), content (str, markdown), projectId (str)",
|
||||||
"checkpoints": (
|
"timelines": (
|
||||||
"title (str, required), projectId (str, required), date (ms timestamp int)"
|
"title (str, required), projectId (str, required), date (ms timestamp int)"
|
||||||
),
|
),
|
||||||
"projects": "name (str, required), clientId (str)",
|
"projects": "name (str, required), clientId (str)",
|
||||||
@@ -383,7 +383,10 @@ async def run_local_agent(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud agent runner (stub) ───────────────────────────────────────────────
|
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Default lookback window when an agent has never run before.
|
||||||
|
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
|
||||||
|
|
||||||
|
|
||||||
async def run_cloud_agent(
|
async def run_cloud_agent(
|
||||||
@@ -392,26 +395,199 @@ async def run_cloud_agent(
|
|||||||
run_log: AgentRunLog,
|
run_log: AgentRunLog,
|
||||||
device_mgr: DeviceConnectionManager,
|
device_mgr: DeviceConnectionManager,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute a cloud connector agent run.
|
"""Execute a cloud connector agent run end-to-end.
|
||||||
|
|
||||||
.. note::
|
Steps:
|
||||||
This is a **stub** — provider integrations (Gmail, Teams, Outlook)
|
|
||||||
are implemented in Step 3.6. The run is immediately marked as an
|
1. Verify the user's device is online — results are pushed to Electron
|
||||||
error with an informative message.
|
via WS tool-call frames. If no device is connected, abort.
|
||||||
|
2. Decrypt the stored OAuth token from ``config.oauth_token_encrypted``.
|
||||||
|
3. Instantiate the provider client (Gmail or MS Graph).
|
||||||
|
4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for
|
||||||
|
the first run) applying ``config.filter_config`` filters.
|
||||||
|
5. For each message/email call ``_extract_items_from_content`` with
|
||||||
|
``config.prompt_template`` to get structured ``{table, data}`` items.
|
||||||
|
6. Push each item to Electron as an ``insert`` tool-call.
|
||||||
|
7. If the provider refreshed its access token, re-encrypt and write it
|
||||||
|
back to ``config.oauth_token_encrypted``.
|
||||||
|
8. Persist the run outcome via ``_finalize_run``.
|
||||||
"""
|
"""
|
||||||
|
run_id = run_log.id
|
||||||
|
|
||||||
|
# ── 1. Device online check ─────────────────────────────────────────
|
||||||
|
if not device_mgr.is_online(user_id):
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: skip cloud run=%s — no device online for user=%s",
|
||||||
|
run_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=["No connected device — cloud agent results cannot be delivered"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 2. Decrypt OAuth token ─────────────────────────────────────────
|
||||||
|
from app.integrations import decrypt_token, encrypt_token, get_provider
|
||||||
|
|
||||||
|
if not config.oauth_token_encrypted:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
credentials_info = decrypt_token(config.oauth_token_encrypted)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.error("agent_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Failed to decrypt OAuth token: {exc}"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 3. Instantiate provider client ────────────────────────────────
|
||||||
|
try:
|
||||||
|
provider = get_provider(config.provider, credentials_info)
|
||||||
|
except ValueError as exc:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[str(exc)],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 4. Fetch messages ─────────────────────────────────────────────
|
||||||
|
since: datetime | None = config.last_run_at
|
||||||
|
if since is None:
|
||||||
|
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
|
||||||
|
if since.tzinfo is None:
|
||||||
|
since = since.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if config.provider == "gmail":
|
||||||
|
raw_messages = await provider.fetch_messages( # type: ignore[union-attr]
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "outlook":
|
||||||
|
raw_messages = await provider.fetch_emails( # type: ignore[union-attr]
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "teams":
|
||||||
|
raw_messages = await provider.fetch_messages( # type: ignore[union-attr]
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_messages = []
|
||||||
|
except RuntimeError as exc:
|
||||||
|
logger.error(
|
||||||
|
"agent_runner: provider fetch failed for cloud agent %s: %s",
|
||||||
|
config.id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Provider fetch failed: {exc}"],
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: cloud agent %s (provider=%s) for user=%s — pending Step 3.6",
|
"agent_runner: cloud agent %s fetched %d item(s) from %s for user=%s",
|
||||||
config.id,
|
config.id,
|
||||||
|
len(raw_messages),
|
||||||
config.provider,
|
config.provider,
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ── 5–6. Extract + insert ─────────────────────────────────────────
|
||||||
|
for msg in raw_messages:
|
||||||
|
content_text = msg.as_text
|
||||||
|
if not content_text:
|
||||||
|
continue
|
||||||
|
items_processed += 1
|
||||||
|
try:
|
||||||
|
extracted = await _extract_items_from_content(
|
||||||
|
config.prompt_template, content_text, config.data_types
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"LLM extraction error for message {msg.id!r}: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for item in extracted:
|
||||||
|
try:
|
||||||
|
result = await _send_insert_to_client(
|
||||||
|
user_id, item["table"], item["data"], device_mgr
|
||||||
|
)
|
||||||
|
if result.get("error"):
|
||||||
|
errors.append(
|
||||||
|
f"Insert failed ({item['table']}, msg={msg.id!r}): {result['error']}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
items_created += 1
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
errors.append(
|
||||||
|
f"Timed out awaiting insert ack ({item['table']}, msg={msg.id!r})"
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
errors.append(f"Insert error ({item['table']}, msg={msg.id!r}): {exc}")
|
||||||
|
|
||||||
|
# ── 7. Persist refreshed token (if any) ───────────────────────────
|
||||||
|
refreshed = getattr(provider, "refreshed_credentials", None)
|
||||||
|
if refreshed:
|
||||||
|
try:
|
||||||
|
new_encrypted = encrypt_token(refreshed)
|
||||||
|
async with async_session() as db:
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
|
||||||
|
)
|
||||||
|
cfg_row = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg_row:
|
||||||
|
cfg_row.oauth_token_encrypted = new_encrypted
|
||||||
|
await db.commit()
|
||||||
|
logger.debug("agent_runner: refreshed OAuth token persisted for agent %s", config.id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to persist refreshed token for agent %s: %s", config.id, exc)
|
||||||
|
|
||||||
|
# ── 8. Finalise ────────────────────────────────────────────────────
|
||||||
|
if errors and items_created == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
await _finalize_run(
|
await _finalize_run(
|
||||||
run_log,
|
run_log,
|
||||||
status="error",
|
status=final_status,
|
||||||
errors=[
|
items_processed=items_processed,
|
||||||
f"Cloud provider integrations for '{config.provider}' are not yet "
|
items_created=items_created,
|
||||||
"implemented. This feature arrives in Step 3.6."
|
errors=errors,
|
||||||
],
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d",
|
||||||
|
run_id,
|
||||||
|
final_status,
|
||||||
|
items_processed,
|
||||||
|
items_created,
|
||||||
|
len(errors),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -519,13 +695,21 @@ async def _finalize_run(
|
|||||||
managed.errors = errors or []
|
managed.errors = errors or []
|
||||||
managed.completed_at = now
|
managed.completed_at = now
|
||||||
|
|
||||||
if update_config_last_run and config_id and config_type == "local":
|
if update_config_last_run and config_id:
|
||||||
cfg_result = await db.execute(
|
if config_type == "local":
|
||||||
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
cfg_result = await db.execute(
|
||||||
)
|
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
||||||
cfg = cfg_result.scalar_one_or_none()
|
)
|
||||||
if cfg:
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
cfg.last_run_at = now
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
elif config_type == "cloud":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|||||||
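The fetch window (step 4) and the final status (step 8) of `run_cloud_agent` are small pure computations. The sketch below restates them in isolation, assuming the same rules as the diff; `resolve_since` and `classify_run` are hypothetical helper names that do not exist in the codebase.

```python
from __future__ import annotations

from datetime import datetime, timedelta, timezone

# Mirrors _CLOUD_DEFAULT_LOOKBACK_DAYS from the diff.
DEFAULT_LOOKBACK_DAYS = 7


def resolve_since(last_run_at: datetime | None, now: datetime) -> datetime:
    """Return the lower bound of the fetch window as an aware datetime."""
    if last_run_at is None:
        # First run: look back a fixed number of days.
        return now - timedelta(days=DEFAULT_LOOKBACK_DAYS)
    if last_run_at.tzinfo is None:
        # Naive timestamps are interpreted as UTC, as in the diff.
        return last_run_at.replace(tzinfo=timezone.utc)
    return last_run_at


def classify_run(errors: list[str], items_created: int) -> str:
    """Map the collected errors and created items to a run status."""
    if errors and items_created == 0:
        return "error"
    if errors:
        return "partial"
    return "success"
```

Keeping these as standalone functions would make the window and status rules testable without a provider or a connected device.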
489
app/core/deep_agent.py
Normal file
@@ -0,0 +1,489 @@
|
|||||||
|
"""Deep Agent — ``create_deep_agent`` supervisors for home and floating modes.
|
||||||
|
|
||||||
|
Two supervisor graphs (via ``deepagents.create_deep_agent``):
|
||||||
|
* **HomeSupervisor** — gathers data from multiple domains, presents
|
||||||
|
structured overview with entity/chart tags.
|
||||||
|
* **FloatingSupervisor** — focused, scoped assistant for a single entity/domain.
|
||||||
|
|
||||||
|
Each supervisor delegates to four sub-agents (task, project, note, timeline)
|
||||||
|
via the built-in ``task`` tool provided by ``SubAgentMiddleware``.
|
||||||
|
The sub-agents talk to Electron via ``execute_on_client``.
|
||||||
|
|
||||||
|
Built-in middleware provides: todo-list tracking, virtual filesystem,
|
||||||
|
automatic context summarisation, prompt-caching, and tool-call patching.
|
||||||
|
|
||||||
|
Streaming uses ``astream(stream_mode=["messages", "updates"])`` so that
|
||||||
|
callers can sniff:
|
||||||
|
* ``("messages", (token, metadata))`` — text tokens for streaming
|
||||||
|
* ``("updates", ...)`` — tool call results for mutations
|
||||||
|
|
||||||
|
An ``update_core_memory`` tool is available to both supervisors for
|
||||||
|
persisting user preferences mid-conversation (MemGPT-style).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
|
from deepagents import create_deep_agent
|
||||||
|
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.core.ws_context import (
|
||||||
|
clear_tool_result_collector,
|
||||||
|
set_tool_result_collector,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Sub-agent tool imports ────────────────────────────────────────────
|
||||||
|
|
||||||
|
from app.agents.task_agent import ( # noqa: E402
|
||||||
|
add_task_comment,
|
||||||
|
create_task,
|
||||||
|
delete_task,
|
||||||
|
delete_task_comment,
|
||||||
|
list_task_comments,
|
||||||
|
list_tasks,
|
||||||
|
list_tasks_due_today,
|
||||||
|
update_task,
|
||||||
|
)
|
||||||
|
from app.agents.note_agent import ( # noqa: E402
|
||||||
|
create_note,
|
||||||
|
delete_note,
|
||||||
|
get_note,
|
||||||
|
list_notes,
|
||||||
|
update_note,
|
||||||
|
)
|
||||||
|
from app.agents.project_agent import ( # noqa: E402
|
||||||
|
create_project,
|
||||||
|
delete_project,
|
||||||
|
get_project,
|
||||||
|
list_all_projects,
|
||||||
|
list_projects,
|
||||||
|
update_project,
|
||||||
|
)
|
||||||
|
from app.agents.timeline_agent import ( # noqa: E402
|
||||||
|
create_timeline,
|
||||||
|
delete_timeline,
|
||||||
|
list_timelines,
|
||||||
|
update_timeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Sub-agent definitions ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
_TASK_TOOLS = [
|
||||||
|
list_tasks,
|
||||||
|
create_task,
|
||||||
|
update_task,
|
||||||
|
delete_task,
|
||||||
|
list_tasks_due_today,
|
||||||
|
list_task_comments,
|
||||||
|
add_task_comment,
|
||||||
|
delete_task_comment,
|
||||||
|
]
|
||||||
|
|
||||||
|
_NOTE_TOOLS = [list_notes, get_note, create_note, update_note, delete_note]
|
||||||
|
|
||||||
|
_PROJECT_TOOLS = [
|
||||||
|
list_projects,
|
||||||
|
list_all_projects,
|
||||||
|
get_project,
|
||||||
|
create_project,
|
||||||
|
update_project,
|
||||||
|
delete_project,
|
||||||
|
]
|
||||||
|
|
||||||
|
_TIMELINE_TOOLS = [list_timelines, create_timeline, update_timeline, delete_timeline]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_subagent_specs() -> list[dict[str, Any]]:
|
||||||
|
"""Return SubAgent dicts for the four workspace domains.
|
||||||
|
|
||||||
|
Each dict follows the ``deepagents`` ``SubAgent`` TypedDict:
|
||||||
|
name, description, system_prompt, tools, model
|
||||||
|
The model and middleware are filled in by ``create_deep_agent`` automatically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "task_agent",
|
||||||
|
"description": (
|
||||||
|
"Manages tasks and comments: list, create, update, delete, "
|
||||||
|
"due-today, and comments. Use when the user asks about tasks, "
|
||||||
|
"to-dos, assignments, deadlines, or anything task-related."
|
||||||
|
),
|
||||||
|
"system_prompt": (
|
||||||
|
"You are a task management assistant. You create, update, list, "
|
||||||
|
"and track tasks and their comments.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: todo, in_progress, done\n"
|
||||||
|
" - priority must be one of: high, medium, low\n"
|
||||||
|
" - due_date is a Unix timestamp in milliseconds\n"
|
||||||
|
" - assignees is a JSON-encoded array of strings\n"
|
||||||
|
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
|
||||||
|
" - For update_task, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Always confirm the action in plain, user-friendly language."
|
||||||
|
),
|
||||||
|
"tools": _TASK_TOOLS
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "note_agent",
|
||||||
|
"description": (
|
||||||
|
"Manages notes: list, get, create, update, delete. "
|
||||||
|
"Use when the user asks about notes, documents, or written content."
|
||||||
|
),
|
||||||
|
"system_prompt": (
|
||||||
|
"You are a note-taking assistant. You help users create, retrieve, "
|
||||||
|
"update, and delete Markdown notes in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - content is always Markdown; preserve formatting when updating\n"
|
||||||
|
" - When updating, call get_note first if you need to read existing "
|
||||||
|
"content before appending or replacing sections\n"
|
||||||
|
" - Do not fabricate note content."
|
||||||
|
),
|
||||||
|
"tools": _NOTE_TOOLS
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "project_agent",
|
||||||
|
"description": (
|
||||||
|
"Manages projects: list, get, create, update, archive, delete. "
|
||||||
|
"Use when the user asks about projects, workspaces, or project status."
|
||||||
|
),
|
||||||
|
"system_prompt": (
|
||||||
|
"You are a project management assistant. You help users create, "
|
||||||
|
"find, update, and archive projects.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: active, archived\n"
|
||||||
|
" - Prefer archiving over deletion\n"
|
||||||
|
" - ai_summary is populated only when the user asks for a summary."
|
||||||
|
),
|
||||||
|
"tools": _PROJECT_TOOLS
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "timeline_agent",
|
||||||
|
"description": (
|
||||||
|
"Manages project timelines and milestones: list, create, update, "
|
||||||
|
"delete. Use when the user asks about timelines, milestones, "
|
||||||
|
"deadlines, or project scheduling."
|
||||||
|
),
|
||||||
|
"system_prompt": (
|
||||||
|
"You are a project timeline assistant. Timelines are milestone "
|
||||||
|
"dates that track progress on a project.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - project_id is REQUIRED for every create\n"
|
||||||
|
" - date is a Unix timestamp in milliseconds\n"
|
||||||
|
" - For update_timeline, use -1 for integer fields you do not "
|
||||||
|
"want to change."
|
||||||
|
),
|
||||||
|
"tools": _TIMELINE_TOOLS
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Update core memory tool ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_update_core_memory_tool(user_id: str, db_session_factory):
|
||||||
|
"""Create a tool that persists a key/value preference in core memory."""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_core_memory(key: str, value: str) -> str:
|
||||||
|
"""Save a user preference or fact to long-term core memory.
|
||||||
|
key: short label for the memory (e.g. 'preferred_language', 'timezone')
|
||||||
|
value: the value to remember
|
||||||
|
Use this when the user states a preference or fact worth remembering.
|
||||||
|
"""
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
|
||||||
|
async with db_session_factory() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.update_core(user_id, key, value)
|
||||||
|
return f"Remembered: {key} = {value}"
|
||||||
|
|
||||||
|
return update_core_memory
|
||||||
|
|
||||||
|
|
||||||
|
# ── System prompts ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_HOME_SYSTEM = (
|
||||||
|
"You are Adiuva, a smart workspace assistant on the Home dashboard.\n"
|
||||||
|
"Your job is to help the user by gathering data from their workspace and "
|
||||||
|
"presenting a comprehensive overview.\n\n"
|
||||||
|
"You have sub-agents (task_agent, note_agent, project_agent, "
|
||||||
|
"timeline_agent) accessible via the `task` tool. Delegate to "
|
||||||
|
"the appropriate sub-agent(s) based on the user's request.\n\n"
|
||||||
|
"You also have an update_core_memory tool — use it when the user states "
|
||||||
|
"a preference or important fact worth remembering long-term.\n\n"
|
||||||
|
"IMPORTANT: You do NOT have direct access to workspace data. Always "
|
||||||
|
"delegate to your sub-agents using the `task` tool. Do not attempt to "
|
||||||
|
"answer workspace queries yourself — the sub-agents have the tools to "
|
||||||
|
"fetch and modify data. You can call multiple sub-agents in parallel "
|
||||||
|
"when the request spans multiple domains.\n\n"
|
||||||
|
"## Entity References\n"
|
||||||
|
"When your response mentions specific workspace entities, embed them "
|
||||||
|
"inline using entity tags so the UI can render interactive components.\n"
|
||||||
|
"Format: <type>[comma-separated UUIDs]</type>\n"
|
||||||
|
"Supported types: task, project, note, timeline\n\n"
|
||||||
|
"Example response:\n"
|
||||||
|
" Here is your project:\n"
|
||||||
|
" <project>[abc-123-def]</project>\n"
|
||||||
|
" It has these pending tasks:\n"
|
||||||
|
" <task>[def-456,ghi-789]</task>\n\n"
|
||||||
|
"IMPORTANT: Only include IDs of entities that are directly relevant to "
|
||||||
|
"the user's question. Do NOT dump all entity IDs returned by a tool — "
|
||||||
|
"filter to only the ones the user asked about or that matter for the answer.\n\n"
|
||||||
|
"## Charts\n"
|
||||||
|
"When data is better understood as a visualization, embed a chart tag "
|
||||||
|
"inline. The frontend renders it using shadcn/ui Recharts components.\n"
|
||||||
|
"Format: <chart>{{JSON}}</chart>\n\n"
|
||||||
|
"JSON shape:\n"
|
||||||
|
' {{"chartType":"<type>","title":"...","data":[...],"config":{{...}}}}\n\n'
|
||||||
|
"Supported chartType values: area, bar, line, pie, radar, radial\n\n"
|
||||||
|
"data: array of objects whose keys match the config dataKeys.\n"
|
||||||
|
"config: {{ dataKey: {{ label, color }} }} — follows shadcn ChartConfig.\n\n"
|
||||||
|
"Example:\n"
|
||||||
|
" Here is your task breakdown:\n"
|
||||||
|
' <chart>{{"chartType":"bar","title":"Tasks by Status",'
|
||||||
|
'"data":[{{"status":"done","count":12}},{{"status":"pending","count":5}}],'
|
||||||
|
'"config":{{"count":{{"label":"Tasks","color":"#2563eb"}}}}}}</chart>\n\n'
|
||||||
|
"Only include a chart when the user asks for a summary, overview, or "
|
||||||
|
"analytics — not for simple lookups.\n\n"
|
||||||
|
"Memory context:\n{memory_context}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_FLOATING_SYSTEM = (
|
||||||
|
"You are Adiuva, a focused workspace assistant in the floating panel.\n"
|
||||||
|
"The user is currently working in the '{scope_type}' section"
|
||||||
|
"{scope_detail}.\n\n"
|
||||||
|
"You have sub-agents (task_agent, note_agent, project_agent, "
|
||||||
|
"timeline_agent) accessible via the `task` tool. Focus your "
|
||||||
|
"help on the user's current scope, but you can use other sub-agents "
|
||||||
|
"if the request requires it.\n\n"
|
||||||
|
"You also have an update_core_memory tool — use it when the user states "
|
||||||
|
"a preference or important fact worth remembering long-term.\n\n"
|
||||||
|
"IMPORTANT: You do NOT have direct access to workspace data. Always "
|
||||||
|
"delegate to your sub-agents using the `task` tool. Do not attempt to "
|
||||||
|
"answer workspace queries yourself — the sub-agents have the tools to "
|
||||||
|
"fetch and modify data.\n\n"
|
||||||
|
"Provide direct, conversational responses.\n\n"
|
||||||
|
"Memory context:\n{memory_context}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_memory_context(memory: dict[str, Any]) -> str:
|
||||||
|
"""Format the memory dict into a readable string for the system prompt."""
|
||||||
|
if not memory:
|
||||||
|
return "(no memory available)"
|
||||||
|
parts = []
|
||||||
|
if memory.get("core_memory"):
|
||||||
|
parts.append("Preferences: " + json.dumps(memory["core_memory"]))
|
||||||
|
if memory.get("associative_memory"):
|
||||||
|
parts.append("Related memories: " + "; ".join(memory["associative_memory"][:3]))
|
||||||
|
if memory.get("episodic_memory"):
|
||||||
|
parts.append("Recent sessions: " + "; ".join(memory["episodic_memory"][:3]))
|
||||||
|
if memory.get("proactive_hints"):
|
||||||
|
parts.append("Patterns: " + "; ".join(memory["proactive_hints"][:3]))
|
||||||
|
return "\n".join(parts) if parts else "(no memory available)"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Graph builders ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def build_home_graph(
|
||||||
|
user_id: str,
|
||||||
|
memory_context: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
):
|
||||||
|
"""Build the Home supervisor graph."""
|
||||||
|
subagent_specs = _make_subagent_specs()
|
||||||
|
memory_tool = _make_update_core_memory_tool(user_id, db_session_factory)
|
||||||
|
|
||||||
|
prompt = _HOME_SYSTEM.format(
|
||||||
|
memory_context=_format_memory_context(memory_context),
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_deep_agent(
|
||||||
|
model=get_llm(),
|
||||||
|
tools=[memory_tool],
|
||||||
|
system_prompt=prompt,
|
||||||
|
subagents=subagent_specs,
|
||||||
|
name="home_supervisor",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_floating_graph(
|
||||||
|
user_id: str,
|
||||||
|
memory_context: dict[str, Any],
|
||||||
|
scope: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
):
|
||||||
|
"""Build the Floating supervisor graph."""
|
||||||
|
subagent_specs = _make_subagent_specs()
|
||||||
|
memory_tool = _make_update_core_memory_tool(user_id, db_session_factory)
|
||||||
|
|
||||||
|
scope_type = scope.get("type", "general")
|
||||||
|
scope_id = scope.get("id")
|
||||||
|
scope_detail = f" (id: {scope_id})" if scope_id else ""
|
||||||
|
|
||||||
|
prompt = _FLOATING_SYSTEM.format(
|
||||||
|
scope_type=scope_type,
|
||||||
|
scope_detail=scope_detail,
|
||||||
|
memory_context=_format_memory_context(memory_context),
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_deep_agent(
|
||||||
|
model=get_llm(),
|
||||||
|
tools=[memory_tool],
|
||||||
|
system_prompt=prompt,
|
||||||
|
subagents=subagent_specs,
|
||||||
|
name="floating_supervisor",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stream event type ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Events yielded by run_*_stream:
|
||||||
|
# ("token", str) — text token for streaming
|
||||||
|
# ("tool_start", dict) — {"name": "task_agent", "args": {...}}
|
||||||
|
# ("tool_end", dict) — {"name": "task_agent", "result": "..."}
|
||||||
|
# ("mutations", list) - dicts gathered via set_tool_result_collector, yielded once after the stream ends
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stream runners ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _run_graph_stream(
|
||||||
|
graph,
|
||||||
|
message: str,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Run a supervisor graph with streaming, yielding event tuples.
|
||||||
|
|
||||||
|
Uses ``stream_mode=["messages", "updates"]`` to get both token-level
|
||||||
|
streaming and update events for tool calls.
|
||||||
|
"""
|
||||||
|
inputs = {"messages": [HumanMessage(content=message)]}
|
||||||
|
|
||||||
|
collector: list[dict] = []
|
||||||
|
set_tool_result_collector(collector)
|
||||||
|
try:
|
||||||
|
async for stream_mode, chunk in graph.astream(
|
||||||
|
inputs,
|
||||||
|
stream_mode=["messages", "updates"],
|
||||||
|
):
|
||||||
|
if stream_mode == "messages":
|
||||||
|
msg, metadata = chunk
|
||||||
|
agent_name = (
|
||||||
|
metadata.get("lc_agent_name", "?")
|
||||||
|
if isinstance(metadata, dict) else "?"
|
||||||
|
)
|
||||||
|
node = (
|
||||||
|
metadata.get("langgraph_node", "?")
|
||||||
|
if isinstance(metadata, dict) else "?"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log every message event with agent attribution
|
||||||
|
if isinstance(msg, (AIMessage, AIMessageChunk)) and msg.content:
|
||||||
|
logger.info(
|
||||||
|
"[%s] %s node=%s content=%s",
|
||||||
|
agent_name,
|
||||||
|
type(msg).__name__,
|
||||||
|
node,
|
||||||
|
str(msg.content),
|
||||||
|
)
|
||||||
|
elif isinstance(msg, (AIMessage, AIMessageChunk)) and msg.tool_calls:
|
||||||
|
tool_names = [tc["name"] for tc in msg.tool_calls]
|
||||||
|
logger.info(
|
||||||
|
"[%s] %s node=%s tool_calls=%s",
|
||||||
|
agent_name,
|
||||||
|
type(msg).__name__,
|
||||||
|
node,
|
||||||
|
tool_names,
|
||||||
|
)
|
||||||
|
elif hasattr(msg, "name") and hasattr(msg, "content") and msg.content:
|
||||||
|
# ToolMessage — log tool result
|
||||||
|
logger.info(
|
||||||
|
"[%s] ToolMessage tool=%s node=%s result=%s",
|
||||||
|
agent_name,
|
||||||
|
getattr(msg, "name", "?"),
|
||||||
|
node,
|
||||||
|
str(msg.content),
|
||||||
|
)
|
||||||
|
# Only yield tokens from the supervisor's final response
|
||||||
|
# (not from sub-agent internal LLM calls).
|
||||||
|
# Accept both AIMessageChunk (streamed tokens) and AIMessage
|
||||||
|
# (full response from non-streaming providers).
|
||||||
|
# create_deep_agent names the LLM node "model".
|
||||||
|
if (
|
||||||
|
isinstance(msg, (AIMessage, AIMessageChunk))
|
||||||
|
and msg.content
|
||||||
|
and not msg.tool_calls
|
||||||
|
and isinstance(metadata, dict)
|
||||||
|
and metadata.get("langgraph_node") == "model"
|
||||||
|
):
|
||||||
|
yield ("token", str(msg.content))
|
||||||
|
|
||||||
|
elif stream_mode == "updates":
|
||||||
|
# Updates is a dict of {node_name: state_update}
|
||||||
|
if not isinstance(chunk, dict):
|
||||||
|
continue
|
||||||
|
for node_name, state_update in chunk.items():
|
||||||
|
if node_name != "tools":
|
||||||
|
continue
|
||||||
|
# Tool node executed — extract tool call results
|
||||||
|
tool_messages = state_update.get("messages", [])
|
||||||
|
for tool_msg in tool_messages:
|
||||||
|
if hasattr(tool_msg, "name") and hasattr(tool_msg, "content"):
|
||||||
|
yield (
|
||||||
|
"tool_end",
|
||||||
|
{"name": tool_msg.name, "result": str(tool_msg.content)},
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
|
||||||
|
# Yield the collected mutations so callers can attach them to stream_end
|
||||||
|
yield ("mutations", collector)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Run the Home supervisor and yield streaming events."""
|
||||||
|
graph = build_home_graph(user_id, context, db_session_factory)
|
||||||
|
async for event in _run_graph_stream(graph, message):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
|
async def run_floating_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
scope: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Run the Floating supervisor and yield streaming events."""
|
||||||
|
graph = build_floating_graph(user_id, context, scope, db_session_factory)
|
||||||
|
async for event in _run_graph_stream(graph, message):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
) -> str:
|
||||||
|
"""Run the Home supervisor (non-streaming) and return full response text."""
|
||||||
|
graph = build_home_graph(user_id, context, db_session_factory)
|
||||||
|
result = await graph.ainvoke(
|
||||||
|
{"messages": [HumanMessage(content=message)]}
|
||||||
|
)
|
||||||
|
messages = result["messages"]
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if hasattr(msg, "content") and msg.content and not getattr(msg, "tool_calls", None):
|
||||||
|
return str(msg.content)
|
||||||
|
return ""
|
||||||
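The stream runners above yield `(kind, payload)` tuples, so a caller can fold them into a single result. The following is a minimal sketch of such a consumer, not the project's actual WebSocket handler: `fake_stream` is a hypothetical stand-in for `run_home_stream`, and the shape of the mutation dict is an assumption made for illustration.

```python
import asyncio
from typing import Any, AsyncGenerator


async def fake_stream() -> AsyncGenerator[tuple[str, Any], None]:
    """Hypothetical stand-in for run_home_stream(); yields the same event kinds."""
    yield ("token", "You have ")
    yield ("tool_end", {"name": "task_agent", "result": "2 tasks"})
    yield ("token", "2 open tasks.")
    # The mutation payload shape below is assumed, not taken from the source.
    yield ("mutations", [{"table": "tasks", "action": "update"}])


async def consume(stream: AsyncGenerator[tuple[str, Any], None]) -> dict[str, Any]:
    """Fold a stream of (kind, payload) events into one summary dict."""
    text_parts: list[str] = []
    tool_results: list[dict] = []
    mutations: list[dict] = []
    async for kind, payload in stream:
        if kind == "token":
            text_parts.append(payload)      # streamed supervisor text
        elif kind == "tool_end":
            tool_results.append(payload)    # one entry per finished tool call
        elif kind == "mutations":
            mutations = list(payload)       # arrives once, after the graph run
    return {"text": "".join(text_parts), "tools": tool_results, "mutations": mutations}


summary = asyncio.run(consume(fake_stream()))
```

Because `("mutations", ...)` is yielded after the `finally` block, a consumer written this way only sees it when the graph run finishes without raising.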
@@ -1,222 +0,0 @@
|
|||||||
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.schemas import ExecutionPlan, PlanStep
|
|
||||||
|
|
||||||
|
|
||||||
# ── Prompt Template Registry ──────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplateRegistry:
|
|
||||||
"""Server-side store mapping template IDs to prompt text.
|
|
||||||
|
|
||||||
Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
|
|
||||||
The actual prompt text is resolved here on the server, keeping prompt IP
|
|
||||||
out of API responses.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._templates: dict[str, str] = {}
|
|
||||||
|
|
||||||
def register(self, template_id: str, prompt_text: str) -> None:
|
|
||||||
self._templates[template_id] = prompt_text
|
|
||||||
|
|
||||||
def get(self, template_id: str) -> str:
|
|
||||||
"""Resolve a template ID to its prompt text.
|
|
||||||
|
|
||||||
Raises ``KeyError`` if the template is not registered.
|
|
||||||
"""
|
|
||||||
text = self._templates.get(template_id)
|
|
||||||
if text is None:
|
|
||||||
raise KeyError(f"Template not found: {template_id!r}")
|
|
||||||
return text
|
|
||||||
|
|
||||||
def has(self, template_id: str) -> bool:
|
|
||||||
return template_id in self._templates
|
|
||||||
|
|
||||||
def list_ids(self) -> list[str]:
|
|
||||||
"""Return all registered template IDs (never the text)."""
|
|
||||||
return list(self._templates.keys())
|
|
||||||
|
|
||||||
|
|
||||||
# ── Execution Plan Builder ────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlanBuilder:
|
|
||||||
"""Fluent builder for ``ExecutionPlan`` objects.
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_llm_step("tpl_task_agent_default", {"message": user_msg})
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, agent: str) -> None:
|
|
||||||
self._agent = agent
|
|
||||||
self._steps: list[PlanStep] = []
|
|
||||||
|
|
||||||
# ── step adders ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def add_step(
|
|
||||||
self, action: str, params: dict[str, Any] | None = None
|
|
||||||
) -> ExecutionPlanBuilder:
|
|
||||||
"""Append a generic action step with optional parameters."""
|
|
||||||
self._steps.append(PlanStep(action=action, variables=params))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_llm_step(
|
|
||||||
self, template_id: str, variables: dict[str, Any] | None = None
|
|
||||||
) -> ExecutionPlanBuilder:
|
|
||||||
"""Append an LLM step referencing a server-side template by ID."""
|
|
||||||
self._steps.append(
|
|
||||||
PlanStep(action="llm", prompt_template=template_id, variables=variables)
|
|
||||||
)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
|
|
||||||
"""Append a step whose input comes from the output of an earlier step."""
|
|
||||||
self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
|
|
||||||
return self
|
|
||||||
|
|
||||||
# ── build ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def build(self) -> ExecutionPlan:
|
|
||||||
"""Validate step references and return the ``ExecutionPlan``.
|
|
||||||
|
|
||||||
Raises ``ValueError`` if any ``data_from_step`` references a
|
|
||||||
non-existent or future step index.
|
|
||||||
"""
|
|
||||||
for i, step in enumerate(self._steps):
|
|
||||||
if step.data_from_step is not None:
|
|
||||||
if not (0 <= step.data_from_step < i):
|
|
||||||
raise ValueError(
|
|
||||||
f"Step {i}: data_from_step={step.data_from_step} must "
|
|
||||||
f"reference a preceding step index in range 0..{i - 1}"
|
|
||||||
)
|
|
||||||
return ExecutionPlan(agent=self._agent, steps=list(self._steps))
|
|
||||||
|
|
||||||
|
|
||||||
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PlanCache:
|
|
||||||
"""In-memory LRU cache for ``ExecutionPlan`` objects.
|
|
||||||
|
|
||||||
Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
|
|
||||||
The cache also serves as a runtime memoisation layer so that repeated
|
|
||||||
identical intent classifications can skip re-building the plan.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, maxsize: int = 1000) -> None:
|
|
||||||
self._maxsize = maxsize
|
|
||||||
self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
|
|
||||||
|
|
||||||
def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
|
|
||||||
"""Store *plan* under *key*, evicting the LRU entry if at capacity."""
|
|
||||||
if key in self._cache:
|
|
||||||
del self._cache[key] # remove so re-insertion places it at the end
|
|
||||||
elif len(self._cache) >= self._maxsize:
|
|
||||||
self._cache.popitem(last=False) # evict least-recently-used
|
|
||||||
self._cache[key] = plan
|
|
||||||
|
|
||||||
def get_plan(self, key: str) -> ExecutionPlan | None:
|
|
||||||
"""Return the cached plan for *key*, or ``None`` if not present.
|
|
||||||
|
|
||||||
Accessing a plan marks it as most-recently used.
|
|
||||||
"""
|
|
||||||
if key not in self._cache:
|
|
||||||
return None
|
|
||||||
self._cache.move_to_end(key)
|
|
||||||
return self._cache[key]
|
|
||||||
|
|
||||||
def get_all_playbooks(self) -> list[ExecutionPlan]:
|
|
||||||
"""Return all cached plans (most-recently used last)."""
|
|
||||||
return list(self._cache.values())
|
|
||||||
|
|
||||||
|
|
||||||
# ── Module-level singletons ───────────────────────────────────────────
|
|
||||||
|
|
||||||
template_registry = PromptTemplateRegistry()
|
|
||||||
plan_cache = PlanCache()
|
|
||||||
|
|
||||||
|
|
||||||
def _register_builtin_templates() -> None:
|
|
||||||
"""Register the built-in server-side prompt templates.
|
|
||||||
|
|
||||||
These strings never leave the server. Clients only receive the IDs.
|
|
||||||
"""
|
|
||||||
_tpls: dict[str, str] = {
|
|
||||||
"tpl_task_agent_default": (
|
|
||||||
"You are a task management assistant. Help the user create, update, "
|
|
||||||
"list, and track tasks. Use correct status values (todo, in_progress, "
|
|
||||||
"done) and priority values (high, medium, low) from the workspace model."
|
|
||||||
),
|
|
||||||
"tpl_checkpoint_agent_default": (
|
|
||||||
"You are a project checkpoint assistant. Help the user create and manage "
|
|
||||||
"milestone checkpoints on their projects. Every checkpoint requires a "
|
|
||||||
"project_id and a date expressed as a Unix timestamp in milliseconds."
|
|
||||||
),
|
|
||||||
"tpl_project_agent_default": (
|
|
||||||
"You are a project management assistant. Help the user create, find, "
|
|
||||||
"update, and archive projects. Projects have a name, an optional client, "
|
|
||||||
"and a status of either active or archived."
|
|
||||||
),
|
|
||||||
"tpl_note_agent_default": (
|
|
||||||
"You are a note-taking assistant. Help the user create, retrieve, update, "
|
|
||||||
"and delete Markdown notes. Notes can optionally be linked to a project."
|
|
||||||
),
|
|
||||||
"tpl_task_extract_from_project": (
|
|
||||||
"Extract all actionable tasks from the provided project context. "
|
|
||||||
"Return a structured list of tasks, each with a title, inferred priority "
|
|
||||||
"(high, medium, or low), suggested status (todo), and a due_date in "
|
|
||||||
"milliseconds where a deadline can be inferred."
|
|
||||||
),
|
|
||||||
"tpl_note_weekly_summary": (
|
|
||||||
"Generate a weekly project summary note from the provided workspace data. "
|
|
||||||
"Include: tasks completed this week, tasks due soon, active projects, "
|
|
||||||
"and upcoming checkpoints. Format the output as clean Markdown."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
for tid, text in _tpls.items():
|
|
||||||
template_registry.register(tid, text)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_playbooks() -> None:
|
|
||||||
"""Pre-build and cache the built-in playbooks."""
|
|
||||||
playbooks: list[tuple[str, ExecutionPlan]] = [
|
|
||||||
(
|
|
||||||
"create_tasks_from_project",
|
|
||||||
ExecutionPlanBuilder("project_agent")
|
|
||||||
.add_llm_step(
|
|
||||||
"tpl_task_extract_from_project",
|
|
||||||
{"source": "project_context"},
|
|
||||||
)
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build(),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"generate_weekly_note",
|
|
||||||
ExecutionPlanBuilder("note_agent")
|
|
||||||
.add_llm_step(
|
|
||||||
"tpl_note_weekly_summary",
|
|
||||||
{"period": "last_7_days"},
|
|
||||||
)
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build(),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
for key, plan in playbooks:
|
|
||||||
plan_cache.cache_plan(key, plan)
|
|
||||||
|
|
||||||
|
|
||||||
# Initialise on module load
|
|
||||||
_register_builtin_templates()
|
|
||||||
_load_playbooks()
|
|
||||||
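The eviction behaviour of the removed `PlanCache` can be illustrated with a small stand-alone sketch. `TinyLRU` and its string values are hypothetical names used only for this example; the logic mirrors the `OrderedDict` pattern in the deleted file (delete and re-insert on overwrite, `popitem(last=False)` at capacity, `move_to_end` on read).

```python
from collections import OrderedDict
from typing import Optional


class TinyLRU:
    """Illustrative LRU cache following the same OrderedDict pattern as PlanCache."""

    def __init__(self, maxsize: int) -> None:
        self._maxsize = maxsize
        self._data: "OrderedDict[str, str]" = OrderedDict()

    def put(self, key: str, value: str) -> None:
        if key in self._data:
            del self._data[key]              # re-insert so the key moves to the end
        elif len(self._data) >= self._maxsize:
            self._data.popitem(last=False)   # evict the least recently used entry
        self._data[key] = value

    def get(self, key: str) -> Optional[str]:
        if key not in self._data:
            return None
        self._data.move_to_end(key)          # a read marks the key as most recent
        return self._data[key]

    def keys(self) -> list:
        return list(self._data)


cache = TinyLRU(maxsize=2)
cache.put("a", "plan-a")
cache.put("b", "plan-b")
cache.get("a")            # "a" is now the most recently used key
cache.put("c", "plan-c")  # cache is full, so "b" is evicted
remaining = cache.keys()
```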
@@ -1,6 +1,6 @@
|
|||||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
|
Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
|
||||||
instead of directly constructing a provider-specific class. The model string
|
instead of directly constructing a provider-specific class. The model string
|
||||||
follows the `LiteLLM model naming convention
|
follows the `LiteLLM model naming convention
|
||||||
<https://docs.litellm.ai/docs/providers>`_:
|
<https://docs.litellm.ai/docs/providers>`_:
|
||||||
@@ -17,13 +17,21 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
import litellm
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_litellm import ChatLiteLLM
|
||||||
from litellm import get_supported_openai_params # noqa: F401 – validates install
|
from litellm import get_supported_openai_params # noqa: F401 – validates install
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
# Some models (e.g. gpt-5, o-series) reject unsupported params like temperature.
|
||||||
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
@@ -31,6 +39,12 @@ def _api_key_for_model(model: str) -> str | None:
|
|||||||
return settings.ANTHROPIC_API_KEY or None
|
return settings.ANTHROPIC_API_KEY or None
|
||||||
if model.startswith("gemini/") or model.startswith("google/"):
|
if model.startswith("gemini/") or model.startswith("google/"):
|
||||||
return settings.GOOGLE_API_KEY or None
|
return settings.GOOGLE_API_KEY or None
|
||||||
|
if model.startswith("cerebras/"):
|
||||||
|
return settings.CEREBRAS_API_KEY or None
|
||||||
|
if model.startswith("github_copilot/"):
|
||||||
|
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
||||||
|
# No API key is required; returning None lets LiteLLM handle auth.
|
||||||
|
return None
|
||||||
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
|
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
|
||||||
return settings.OPENAI_API_KEY or None
|
return settings.OPENAI_API_KEY or None
|
||||||
|
|
||||||
@@ -39,7 +53,7 @@ def get_llm(
|
|||||||
*,
|
*,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
temperature: float = 0,
|
temperature: float = 0,
|
||||||
) -> ChatOpenAI:
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
"""Return a LangChain chat model backed by LiteLLM.
|
"""Return a LangChain chat model backed by LiteLLM.
|
||||||
|
|
||||||
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
|
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
|
||||||
@@ -55,6 +69,16 @@ def get_llm(
|
|||||||
Sampling temperature. ``0`` = deterministic.
|
Sampling temperature. ``0`` = deterministic.
|
||||||
"""
|
"""
|
||||||
model = model or settings.LLM_MODEL
|
model = model or settings.LLM_MODEL
|
||||||
|
|
||||||
|
# Point LiteLLM to the custom token directory when configured.
|
||||||
|
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||||
|
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||||
|
|
||||||
|
# Use ChatLiteLLM for provider-prefixed models (github_copilot/, anthropic/, etc.)
|
||||||
|
# so LiteLLM handles routing and auth. ChatOpenAI for plain OpenAI model names.
|
||||||
|
if "/" in model:
|
||||||
|
return ChatLiteLLM(model=model, temperature=temperature)
|
||||||
|
|
||||||
return ChatOpenAI(
|
return ChatOpenAI(
|
||||||
model=model,
|
model=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -65,16 +89,28 @@ def get_llm(
|
|||||||
def get_router_llm(
|
def get_router_llm(
|
||||||
*,
|
*,
|
||||||
temperature: float = 0,
|
temperature: float = 0,
|
||||||
) -> ChatOpenAI:
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
"""Return the lighter model used for intent classification / routing."""
|
"""Return the lighter model used for intent classification / routing."""
|
||||||
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
||||||
|
|
||||||
|
|
||||||
async def embed(text: str) -> list[float]:
|
async def embed(text: str) -> list[float]:
|
||||||
"""Return a 1536-dim embedding vector for *text* using text-embedding-3-small."""
|
"""Return an embedding vector for *text*.
|
||||||
|
|
||||||
|
Uses ``settings.LLM_EMBED_MODEL`` so the same provider switch in ``.env``
|
||||||
|
(e.g. ``github_copilot/text-embedding-3-small``) applies here without any
|
||||||
|
code changes. Falls back to the raw AsyncOpenAI client for plain OpenAI
|
||||||
|
model names to preserve existing behaviour.
|
||||||
|
"""
|
||||||
|
model = settings.LLM_EMBED_MODEL
|
||||||
|
|
||||||
|
if "/" in model:  # any provider-prefixed model, including github_copilot/
|
||||||
|
# Use LiteLLM for all provider-prefixed models (Copilot, Bedrock, etc.)
|
||||||
|
# so the provider's auth mechanism is applied correctly.
|
||||||
|
response = await litellm.aembedding(model=model, input=[text])
|
||||||
|
return response.data[0]["embedding"]
|
||||||
|
|
||||||
|
# Plain OpenAI model name — use the raw AsyncOpenAI client (existing path).
|
||||||
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||||
response = await client.embeddings.create(
|
response = await client.embeddings.create(model=model, input=text)
|
||||||
model="text-embedding-3-small",
|
|
||||||
input=text,
|
|
||||||
)
|
|
||||||
return response.data[0].embedding
|
return response.data[0].embedding
|
||||||
|
|||||||
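The routing rule added to `get_llm()` is a simple prefix check: a model string containing `/` goes through LiteLLM, and a plain name goes to the OpenAI client. A minimal sketch of that decision, assuming a hypothetical helper `pick_backend` that returns the class name instead of constructing a client (the model strings are illustrative only):

```python
def pick_backend(model: str) -> str:
    """Return the name of the client class the factory would choose for *model*."""
    # Provider-prefixed names such as "github_copilot/..." contain a slash.
    return "ChatLiteLLM" if "/" in model else "ChatOpenAI"


# Illustrative model strings; they are not taken from the project's .env.
examples = ["gpt-4o", "github_copilot/gpt-4o", "cerebras/llama3.1-8b"]
routing = {model: pick_backend(model) for model in examples}
```

Keeping the decision in one pure function like this makes the branch easy to unit-test without network access or API keys.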
231
app/core/memory_middleware.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
"""Memory Middleware — enrich requests with memory context and store interactions.
|
||||||
|
|
||||||
|
Four-tier memory model (MemGPT-style):
|
||||||
|
core — persistent key/value user preferences, always injected
|
||||||
|
associative — semantic similarity search via pgvector (top-k)
|
||||||
|
episodic — recent session summaries (last N)
|
||||||
|
proactive — behavioral patterns above confidence threshold
|
||||||
|
|
||||||
|
All memory content is encrypted at rest using the per-user Fernet key
|
||||||
|
stored in User.encryption_key. Decryption happens in-memory only.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
memory = MemoryMiddleware(db_session)
|
||||||
|
context = await memory.enrich_context(user_id, message)
|
||||||
|
# ... run agent ...
|
||||||
|
await memory.store_episode(user_id, session_id, message, response)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models import (
|
||||||
|
MemoryAssociative,
|
||||||
|
MemoryCore,
|
||||||
|
MemoryEpisodic,
|
||||||
|
MemoryProactive,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Tuning constants
|
||||||
|
_ASSOCIATIVE_TOP_K = 5
|
||||||
|
_EPISODIC_RECENT_N = 10
|
||||||
|
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryMiddleware:
|
||||||
|
"""Enrich agent context with memory and persist interactions after."""
|
||||||
|
|
||||||
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
|
self._db = db
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
||||||
|
"""Build memory context dict to inject into the agent before LLM call.
|
||||||
|
|
||||||
|
Returns a dict with keys:
|
||||||
|
core_memory — {key: plaintext_value, ...}
|
||||||
|
associative_memory — [plaintext_content, ...] (top-k most recently updated)
|
||||||
|
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||||
|
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||||
|
"""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
core = await self._load_core(user_id, fernet)
|
||||||
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
|
episodic = await self._load_episodic(user_id, fernet)
|
||||||
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"core_memory": core,
|
||||||
|
"associative_memory": associative,
|
||||||
|
"episodic_memory": episodic,
|
||||||
|
"proactive_hints": proactive,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def store_episode(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
message: str,
|
||||||
|
response: str,
|
||||||
|
) -> None:
|
||||||
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
|
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||||
|
latency low. Full LLM summarisation can be added in a later step.
|
||||||
|
"""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||||
|
encrypted = _encrypt(fernet, summary)
|
||||||
|
|
||||||
|
row = MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
summary_encrypted=encrypted,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
||||||
|
"""Upsert a core memory key/value for a user."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, value)
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
if existing is not None:
|
||||||
|
existing.value_encrypted = encrypted
|
||||||
|
else:
|
||||||
|
self._db.add(MemoryCore(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
key=key,
|
||||||
|
value_encrypted=encrypted,
|
||||||
|
))
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
|
"""Load the user's Fernet key from DB. Returns None if missing."""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not user.encryption_key:
|
||||||
|
logger.warning("memory: no encryption_key for user=%s", user_id)
|
||||||
|
return None
|
||||||
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: dict[str, str] = {}
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out[row.key] = plaintext
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_associative(
|
||||||
|
self, user_id: str, message: str, fernet: Fernet
|
||||||
|
) -> list[str]:
|
||||||
|
"""Load top-k associative memories.
|
||||||
|
|
||||||
|
Production: uses pgvector cosine similarity on the message embedding.
|
||||||
|
Current implementation: recency-based fallback (most recently updated rows; *message* is unused, no external embedding call)
|
||||||
|
so tests pass without a live OpenAI key.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc())
|
||||||
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
|
.limit(_EPISODIC_RECENT_N)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryProactive)
|
||||||
|
.where(
|
||||||
|
MemoryProactive.user_id == user_id,
|
||||||
|
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
|
||||||
|
)
|
||||||
|
.order_by(MemoryProactive.confidence.desc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# ── Encryption helpers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _encrypt(fernet: Fernet, plaintext: str) -> str:
|
||||||
|
return fernet.encrypt(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
|
||||||
|
"""Decrypt and return plaintext, or None on error (corrupted/wrong key)."""
|
||||||
|
try:
|
||||||
|
return fernet.decrypt(ciphertext.encode()).decode()
|
||||||
|
except (InvalidToken, Exception) as exc:
|
||||||
|
logger.warning("memory: decrypt failed: %s", exc)
|
||||||
|
return None
|
||||||
@@ -1,166 +0,0 @@
|
|||||||
"""Orchestrator — LLM-based intent router and agent pipeline."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, AsyncGenerator
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
|
|
||||||
from app.core.agent_registry import AgentRegistry
|
|
||||||
from app.core.llm import get_router_llm
|
|
||||||
from app.core.agent_registry import registry as _default_registry
|
|
||||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
|
||||||
|
|
||||||
_FALLBACK_AGENT = "task_agent"
|
|
||||||
|
|
||||||
_CLASSIFY_SYSTEM = (
|
|
||||||
"You are an intent classifier. Given the user message and context, decide "
|
|
||||||
"which agent to route to.\n"
|
|
||||||
"Available agents: {agents}\n"
|
|
||||||
"Respond with just the agent name, nothing else."
|
|
||||||
)
|
|
||||||
|
|
||||||
_SYNTHESIZE_HUMAN = (
|
|
||||||
"Combine the following agent results into one coherent response.\n\n"
|
|
||||||
"Agent results:\n{results}\n\n"
|
|
||||||
"Original message: {message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_llm():
|
|
||||||
return get_router_llm()
|
|
||||||
|
|
||||||
|
|
||||||
async def classify_intent(
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> str:
|
|
||||||
"""Use gpt-4o-mini to classify intent and return the matching agent name.
|
|
||||||
|
|
||||||
Falls back to ``task_agent`` when the registry is empty or the model
|
|
||||||
returns a name that is not registered.
|
|
||||||
"""
|
|
||||||
agents = reg.list_agents()
|
|
||||||
if not agents:
|
|
||||||
return _FALLBACK_AGENT
|
|
||||||
|
|
||||||
system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
|
|
||||||
# Truncate context to keep the classification prompt short
|
|
||||||
human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
|
|
||||||
|
|
||||||
llm = _make_llm()
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[SystemMessage(content=system), HumanMessage(content=human)]
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_name = str(response.content).strip().lower()
|
|
||||||
known = {a["name"] for a in agents}
|
|
||||||
return agent_name if agent_name in known else _FALLBACK_AGENT
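The fallback rule at the end of `classify_intent` can be tried in isolation. The following is a minimal sketch, not the project's code: `resolve_agent` and `FALLBACK_AGENT` are hypothetical names, and the registry is assumed to be a plain list of dicts that each carry a `name` key, as `reg.list_agents()` appears to return.

```python
FALLBACK_AGENT = "task_agent"  # assumed default, mirroring _FALLBACK_AGENT


def resolve_agent(raw_reply: str, registered: list[dict[str, str]]) -> str:
    """Map a raw model reply onto a registered agent name, or fall back."""
    # Models often add whitespace or capitalisation; normalise before matching.
    candidate = raw_reply.strip().lower()
    known = {agent["name"] for agent in registered}
    return candidate if candidate in known else FALLBACK_AGENT


agents = [{"name": "task_agent"}, {"name": "note_agent"}]
print(resolve_agent("  Note_Agent\n", agents))
print(resolve_agent("calendar_agent", agents))
```

Only exact matches after normalisation are accepted, so a reply such as "note_agent." with trailing punctuation would still fall back to the default.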
|
|
||||||
|
|
||||||
|
|
||||||
async def route_single(
|
|
||||||
agent_name: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""Route to a single agent and wrap the result in a ``ChatResponse``."""
|
|
||||||
response_text = await reg.call_agent(agent_name, message, context)
|
|
||||||
return ChatResponse(response=response_text)
|
|
||||||
|
|
||||||
|
|
||||||
async def route_pipeline(
|
|
||||||
agent_names: list[str],
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""Execute agents sequentially; each agent receives previous results in context.
|
|
||||||
|
|
||||||
A final LLM synthesis call merges all results into one coherent response.
|
|
||||||
"""
|
|
||||||
previous_results: list[str] = []
|
|
||||||
|
|
||||||
for agent_name in agent_names:
|
|
||||||
ctx = {**context, "previous_results": list(previous_results)}
|
|
||||||
result = await reg.call_agent(agent_name, message, ctx)
|
|
||||||
previous_results.append(result)
|
|
||||||
|
|
||||||
results_str = "\n\n".join(
|
|
||||||
f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
|
|
||||||
)
|
|
||||||
human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
|
|
||||||
llm = _make_llm()
|
|
||||||
synthesis = await llm.ainvoke([HumanMessage(content=human)])
|
|
||||||
return ChatResponse(response=str(synthesis.content))
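The sequential hand-off in `route_pipeline` can be sketched without an LLM or registry. In this sketch the agents are stub coroutines and `run_pipeline` is a hypothetical helper; only the idea of passing a copy of `previous_results` through the context is taken from the code above.

```python
import asyncio
from typing import Any, Awaitable, Callable

Agent = Callable[[str, dict[str, Any]], Awaitable[str]]


async def run_pipeline(
    agents: dict[str, Agent],
    order: list[str],
    message: str,
    context: dict[str, Any],
) -> list[str]:
    """Call agents one after another, exposing earlier results to later ones."""
    previous: list[str] = []
    for name in order:
        # Pass a copy so an agent cannot mutate the shared history.
        ctx = {**context, "previous_results": list(previous)}
        previous.append(await agents[name](message, ctx))
    return previous


async def first(message: str, ctx: dict[str, Any]) -> str:
    return f"first saw {len(ctx['previous_results'])} results"


async def second(message: str, ctx: dict[str, Any]) -> str:
    return f"second saw {len(ctx['previous_results'])} results"


results = asyncio.run(run_pipeline({"a": first, "b": second}, ["a", "b"], "hi", {}))
print(results)
```

The synthesis step is left out here; in the original it is one further LLM call over the joined results.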
|
|
||||||
|
|
||||||
|
|
||||||
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
|
|
||||||
"""Build an ``ExecutionPlan`` for the resolved agent.
|
|
||||||
|
|
||||||
Uses ``ExecutionPlanBuilder`` with the server-side template registry.
|
|
||||||
If a default template exists for the agent, an LLM step is emitted;
|
|
||||||
otherwise a plain ``handle`` action step is used.
|
|
||||||
"""
|
|
||||||
from app.core.execution_plan import ExecutionPlanBuilder, template_registry
|
|
||||||
|
|
||||||
template_id = f"tpl_{agent_name}_default"
|
|
||||||
builder = ExecutionPlanBuilder(agent_name)
|
|
||||||
if template_registry.has(template_id):
|
|
||||||
builder.add_llm_step(template_id, {"message": message})
|
|
||||||
else:
|
|
||||||
builder.add_step("handle", {"message": message})
|
|
||||||
return builder.build()
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate(
|
|
||||||
request: ChatRequest,
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> ChatResponse | ExecutionPlan:
|
|
||||||
"""Main orchestration entry point.
|
|
||||||
|
|
||||||
* Classifies the user's intent to select an agent.
|
|
||||||
* ``execution_mode == 'direct'``: routes to the agent and returns a
|
|
||||||
``ChatResponse``.
|
|
||||||
* ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
|
|
||||||
resolved agent and a template-ID-only step (prompt IP stays server-side).
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
|
|
||||||
context = request.context.model_dump()
|
|
||||||
agent_name = await classify_intent(request.message, context, reg)
|
|
||||||
|
|
||||||
if request.execution_mode == "direct":
|
|
||||||
return await route_single(agent_name, request.message, context, reg)
|
|
||||||
|
|
||||||
# plan mode — return plan, do not execute
|
|
||||||
return _build_plan(agent_name, request.message)
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_stream(
|
|
||||||
request: ChatRequest,
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Streaming orchestration — yields plain text chunks only.
|
|
||||||
|
|
||||||
The WebSocket handler in ``app/api/routes/chat.py`` is responsible for
|
|
||||||
wrapping each chunk in a ``text_chunk`` frame and sending the final
|
|
||||||
``final`` frame once the generator is exhausted.
|
|
||||||
|
|
||||||
Agents do not yet support token-level streaming; the full response is
|
|
||||||
fetched first (which may involve multiple WS round-trips for tool calls),
|
|
||||||
then emitted in fixed-size chunks.
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
|
|
||||||
context = request.context.model_dump()
|
|
||||||
agent_name = await classify_intent(request.message, context, reg)
|
|
||||||
response_text = await reg.call_agent(agent_name, request.message, context)
|
|
||||||
|
|
||||||
chunk_size = 50
|
|
||||||
for i in range(0, len(response_text), chunk_size):
|
|
||||||
yield response_text[i : i + chunk_size]
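The fixed-size chunking at the end of `orchestrate_stream` behaves as follows. This is a standalone sketch with hypothetical names (`chunked`, `collect`); it reuses only the slicing loop, not the agent call.

```python
import asyncio
from collections.abc import AsyncGenerator


async def chunked(text: str, chunk_size: int = 50) -> AsyncGenerator[str, None]:
    """Yield consecutive slices of at most chunk_size characters."""
    for start in range(0, len(text), chunk_size):
        yield text[start : start + chunk_size]


async def collect(text: str, chunk_size: int) -> list[str]:
    return [chunk async for chunk in chunked(text, chunk_size)]


chunks = asyncio.run(collect("a" * 120, 50))
print([len(chunk) for chunk in chunks])
```

One consequence worth noting: an empty response yields no chunks at all, so the consumer has to be prepared for a stream that ends without any text frame.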
|
|
||||||
141
app/core/output_formatter.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.
|
||||||
|
|
||||||
|
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
|
||||||
|
* ``("token", str)`` — supervisor text token
|
||||||
|
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
|
||||||
|
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
||||||
|
|
||||||
|
HomeFormatter:
|
||||||
|
* Streams text tokens as-is → emits ``WsStreamText``
|
||||||
|
(text may contain inline ``<type>[id,...]</type>`` entity tags
|
||||||
|
for the frontend to parse and render as interactive components)
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
|
||||||
|
FloatingFormatter:
|
||||||
|
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
|
||||||
|
* Streams text tokens → emits ``WsStreamText``
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.schemas import (
|
||||||
|
WsFloatingDomain,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Map sub-agent tool name → floating domain / entity type
|
||||||
|
_AGENT_DOMAIN: dict[str, str] = {
|
||||||
|
"task_agent": "tasks",
|
||||||
|
"timeline_agent": "timelines",
|
||||||
|
"note_agent": "notes",
|
||||||
|
"project_agent": "projects",
|
||||||
|
}
|
||||||
|
|
||||||
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
|
class HomeFormatter:
|
||||||
|
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
|
||||||
|
|
||||||
|
Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
|
||||||
|
embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
|
||||||
|
is responsible for parsing those and rendering interactive components.
|
||||||
|
Mutations are attached to ``WsStreamEnd``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self._mutations: list[dict] = []
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
|
||||||
|
async for event_type, data in event_stream:
|
||||||
|
if event_type == "token":
|
||||||
|
if data:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||||
|
|
||||||
|
elif event_type == "mutations":
|
||||||
|
self._mutations = data or []
|
||||||
|
|
||||||
|
yield WsStreamEnd(
|
||||||
|
request_id=self.request_id,
|
||||||
|
mutations=[
|
||||||
|
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||||
|
for m in self._mutations
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FloatingFormatter:
|
||||||
|
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
|
||||||
|
|
||||||
|
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
|
||||||
|
``task_agent`` → ``"tasks"``), then streams text tokens as plain
|
||||||
|
``WsStreamText``. No block parsing for floating context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self._mutations: list[dict] = []
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
domain_sent = False
|
||||||
|
|
||||||
|
async for event_type, data in event_stream:
|
||||||
|
if event_type == "tool_end" and not domain_sent:
|
||||||
|
# Sniff domain from the first sub-agent that completes
|
||||||
|
name = data.get("name", "")
|
||||||
|
domain = _AGENT_DOMAIN.get(name, "tasks")
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain=domain, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
domain_sent = True
|
||||||
|
|
||||||
|
elif event_type == "token":
|
||||||
|
if not domain_sent:
|
||||||
|
# First token arrived before any tool_end — default domain
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain="tasks", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
domain_sent = True
|
||||||
|
if data:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||||
|
|
||||||
|
elif event_type == "mutations":
|
||||||
|
self._mutations = data or []
|
||||||
|
|
||||||
|
# If no events triggered domain_sent (edge case), still emit structure
|
||||||
|
if not domain_sent:
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain="tasks", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
|
||||||
|
yield WsStreamEnd(
|
||||||
|
request_id=self.request_id,
|
||||||
|
mutations=[
|
||||||
|
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||||
|
for m in self._mutations
|
||||||
|
],
|
||||||
|
)
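The frame ordering that `FloatingFormatter.format` produces can be checked with plain tuples standing in for the `Ws*` schema classes. This is a simplified sketch under that assumption; `floating_frames`, `AGENT_DOMAIN` and the frame labels are illustrative names, not part of `app.schemas`.

```python
import asyncio
from collections.abc import AsyncGenerator
from typing import Any

AGENT_DOMAIN = {"task_agent": "tasks", "note_agent": "notes"}  # assumed subset

Event = tuple[str, Any]


async def floating_frames(events: AsyncGenerator[Event, None]) -> AsyncGenerator[Event, None]:
    """Emit domain, stream_start, text..., stream_end in that order."""
    domain_sent = False
    mutations: list[dict] = []
    async for kind, data in events:
        if kind == "tool_end" and not domain_sent:
            # The first finished sub-agent decides the domain.
            yield ("domain", AGENT_DOMAIN.get(data.get("name", ""), "tasks"))
            yield ("stream_start", None)
            domain_sent = True
        elif kind == "token":
            if not domain_sent:
                # Text arrived before any tool finished: use the default domain.
                yield ("domain", "tasks")
                yield ("stream_start", None)
                domain_sent = True
            if data:
                yield ("text", data)
        elif kind == "mutations":
            mutations = data or []
    if not domain_sent:
        yield ("domain", "tasks")
        yield ("stream_start", None)
    yield ("stream_end", mutations)


async def demo_events() -> AsyncGenerator[Event, None]:
    yield ("tool_end", {"name": "note_agent", "result": "ok"})
    yield ("token", "Saved ")
    yield ("token", "")  # empty tokens are dropped
    yield ("token", "your note.")
    yield ("mutations", [{"action": "insert", "table": "notes", "data": {}}])


async def collect() -> list[Event]:
    return [frame async for frame in floating_frames(demo_events())]


frames = asyncio.run(collect())
print([kind for kind, _ in frames])
```

Whatever the input, exactly one domain frame and one start frame precede the text, and the end frame always comes last.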
|
||||||
@@ -7,16 +7,35 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any, Callable, Coroutine
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Holds the execute callback for the current WS session.
|
# Holds the execute callback for the current WS session.
|
||||||
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
# Set by the chat WS handler before the deep agent runs; cleared after.
|
||||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||||
"_client_executor"
|
"_client_executor"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Optional collector that captures raw execute_on_client results.
|
||||||
|
# Set by the deep agent tool loop to capture CRUD mutations.
|
||||||
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
|
"_tool_result_collector", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||||
|
"""Register *lst* as the collector for this async context."""
|
||||||
|
_tool_result_collector.set(lst)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_tool_result_collector() -> None:
|
||||||
|
"""Clear the collector (best-effort)."""
|
||||||
|
_tool_result_collector.set(None)
|
||||||
|
|
||||||
|
|
||||||
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
||||||
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
||||||
@@ -65,4 +84,17 @@ async def execute_on_client(
|
|||||||
if limit is not None:
|
if limit is not None:
|
||||||
payload["limit"] = limit
|
payload["limit"] = limit
|
||||||
|
|
||||||
return await callback(payload)
|
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
|
||||||
|
result = await callback(payload)
|
||||||
|
if result is None:
|
||||||
|
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
|
||||||
|
else:
|
||||||
|
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||||
|
collector = _tool_result_collector.get(None)
|
||||||
|
if collector is not None and action in ("insert", "update", "delete"):
|
||||||
|
collector.append({
|
||||||
|
"action": action,
|
||||||
|
"table": table,
|
||||||
|
"data": data or {},
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from app.config.settings import settings
|
|||||||
engine = create_async_engine(
|
engine = create_async_engine(
|
||||||
settings.DATABASE_URL,
|
settings.DATABASE_URL,
|
||||||
pool_pre_ping=True,
|
pool_pre_ping=True,
|
||||||
echo=settings.ENV == "dev",
|
echo=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|||||||
164
app/integrations/__init__.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""Cloud provider integration utilities.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
* Shared message dataclasses (``EmailMessage``, ``ChatMessage``) used by
|
||||||
|
both the Gmail and MS Graph clients and consumed by ``agent_runner``.
|
||||||
|
* ``get_provider()`` — factory that returns the correct client given a
|
||||||
|
provider name and decrypted OAuth credentials dict.
|
||||||
|
* ``encrypt_token()`` / ``decrypt_token()`` — Fernet-based at-rest
|
||||||
|
encryption for OAuth tokens stored in ``cloud_agent_configs``.
|
||||||
|
|
||||||
|
Encryption rationale
|
||||||
|
--------------------
|
||||||
|
Unlike user content (which is E2E-encrypted client-side and **never**
|
||||||
|
decrypted server-side), OAuth tokens *must* be decrypted server-side
|
||||||
|
because the backend makes provider API calls on behalf of the user.
|
||||||
|
The Fernet key lives solely in ``OAUTH_ENCRYPTION_KEY`` env var — it
|
||||||
|
is never returned to clients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Shared message types ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmailMessage:
|
||||||
|
"""A single email message fetched from Gmail or Outlook."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
subject: str
|
||||||
|
sender: str
|
||||||
|
body_text: str
|
||||||
|
date: datetime
|
||||||
|
labels: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_text(self) -> str:
|
||||||
|
"""Return a human-readable text representation for LLM extraction."""
|
||||||
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
|
||||||
|
return (
|
||||||
|
f"From: {self.sender}\n"
|
||||||
|
f"Date: {date_str}{labels_str}\n"
|
||||||
|
f"Subject: {self.subject}\n\n"
|
||||||
|
f"{self.body_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage:
|
||||||
|
"""A single Teams chat or channel message fetched from MS Graph."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
content: str
|
||||||
|
sender: str
|
||||||
|
channel: str | None
|
||||||
|
date: datetime
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_text(self) -> str:
|
||||||
|
"""Return a human-readable text representation for LLM extraction."""
|
||||||
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
channel_str = f" [channel: {self.channel}]" if self.channel else ""
|
||||||
|
return (
|
||||||
|
f"From: {self.sender}\n"
|
||||||
|
f"Date: {date_str}{channel_str}\n\n"
|
||||||
|
f"{self.content}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fernet helpers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fernet() -> Fernet:
|
||||||
|
"""Return a ``Fernet`` instance using ``settings.OAUTH_ENCRYPTION_KEY``.
|
||||||
|
|
||||||
|
Raises ``RuntimeError`` if ``OAUTH_ENCRYPTION_KEY`` is not set — callers
|
||||||
|
must ensure this is configured before persisting OAuth tokens.
|
||||||
|
"""
|
||||||
|
key = settings.OAUTH_ENCRYPTION_KEY
|
||||||
|
if not key:
|
||||||
|
raise RuntimeError(
|
||||||
|
"OAUTH_ENCRYPTION_KEY is not set. "
|
||||||
|
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
|
||||||
|
)
|
||||||
|
return Fernet(key.encode() if isinstance(key, str) else key)
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_token(token_info: dict) -> str:
|
||||||
|
"""Fernet-encrypt an OAuth credential dict and return a base64 string.
|
||||||
|
|
||||||
|
Stores the full ``{access_token, refresh_token, token_uri, client_id,
|
||||||
|
client_secret, scopes, expiry}`` dict (or equivalent MSAL shape).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
|
||||||
|
ValueError: ``token_info`` is not a non-empty dict.
|
||||||
|
"""
|
||||||
|
if not isinstance(token_info, dict) or not token_info:
|
||||||
|
raise ValueError("token_info must be a non-empty dict")
|
||||||
|
plaintext = json.dumps(token_info).encode("utf-8")
|
||||||
|
return _get_fernet().encrypt(plaintext).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_token(encrypted: str) -> dict:
|
||||||
|
"""Decrypt a Fernet-encrypted token string and return the credential dict.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
|
||||||
|
ValueError: The encrypted string is invalid or was encrypted with a
|
||||||
|
different key.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
|
||||||
|
return json.loads(plaintext)
|
||||||
|
except (InvalidToken, json.JSONDecodeError) as exc:
|
||||||
|
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
# ── Provider factory ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider(
|
||||||
|
provider: str,
|
||||||
|
credentials_info: dict,
|
||||||
|
) -> "GmailClient | MSGraphClient":
|
||||||
|
"""Return the correct provider client for *provider*.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
provider:
|
||||||
|
One of ``"gmail"``, ``"outlook"``, ``"teams"``.
|
||||||
|
credentials_info:
|
||||||
|
Decrypted OAuth credential dict (Google or Microsoft shape).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Unknown provider name.
|
||||||
|
"""
|
||||||
|
if provider == "gmail":
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
return GmailClient(credentials_info)
|
||||||
|
if provider in {"outlook", "teams"}:
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
return MSGraphClient(credentials_info)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown cloud provider {provider!r}. "
|
||||||
|
"Supported: 'gmail', 'outlook', 'teams'."
|
||||||
|
)
|
||||||
335
app/integrations/gmail.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
"""Gmail API client for cloud agent integration.
|
||||||
|
|
||||||
|
Wraps the Google Gmail REST API to fetch email messages matching a
|
||||||
|
``filter_config`` dict. Uses the official ``google-api-python-client``
|
||||||
|
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
|
||||||
|
blocking the event loop.
|
||||||
|
|
||||||
|
Token refresh is handled transparently: when the stored access token has
|
||||||
|
expired, ``google.auth.transport.requests.Request`` will use the refresh
|
||||||
|
token to obtain a fresh one. The caller is responsible for persisting
|
||||||
|
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
|
||||||
|
(see ``agent_runner.run_cloud_agent``).
|
||||||
|
|
||||||
|
Credential dict shape (Google OAuth2):
|
||||||
|
{
|
||||||
|
"token": "<access_token>",
|
||||||
|
"refresh_token": "<refresh_token>",
|
||||||
|
"token_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"client_id": "<client_id>",
|
||||||
|
"client_secret": "<client_secret>",
|
||||||
|
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
|
||||||
|
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import email.utils
|
||||||
|
import html
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.integrations import EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Gmail search date format — e.g. "after:2025/01/01"
|
||||||
|
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
||||||
|
|
||||||
|
# Maximum characters of body text forwarded to the LLM.
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
|
||||||
|
# Maximum messages retrieved per run (prevents runaway quota usage).
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gmail_query(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
"""Build a Gmail search query string from *filter_config* and *since*.
|
||||||
|
|
||||||
|
Supported ``filter_config`` keys:
|
||||||
|
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
|
||||||
|
senders (list[str]): Sender addresses or domains to include
|
||||||
|
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
|
||||||
|
|
||||||
|
The later of ``since`` (from the last run) and ``date_range.from`` is used
|
||||||
|
as the effective lower bound, so an older ``date_range.from`` never widens the search.
|
||||||
|
"""
|
||||||
|
parts: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
# Labels — joined with OR when multiple given.
|
||||||
|
labels: list[str] = cfg.get("labels", [])
|
||||||
|
if labels:
|
||||||
|
if len(labels) == 1:
|
||||||
|
parts.append(f"label:{labels[0]}")
|
||||||
|
else:
|
||||||
|
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
||||||
|
parts.append(f"({label_expr})")
|
||||||
|
|
||||||
|
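# Senders: each prefixed with "from:", joined with OR when multiple given.
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
if senders:
|
||||||
|
sender_expr = " OR ".join(f"from:{sender}" for sender in senders)
|
||||||
|
parts.append(f"({sender_expr})" if len(senders) > 1 else sender_expr)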
|
||||||
|
|
||||||
|
# Date range.
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
|
||||||
|
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " ".join(parts)
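The label and date handling in `_build_gmail_query` can be exercised on its own. This sketch covers only those two parts; `label_clause` and `effective_since` are hypothetical helpers, and `since` is assumed to be timezone-aware (comparing a naive `since` with the parsed, UTC-aware date would raise `TypeError`).

```python
from datetime import datetime, timezone
from typing import Optional


def label_clause(labels: list[str]) -> str:
    """Return a Gmail label filter; multiple labels are OR-ed in parentheses."""
    if not labels:
        return ""
    if len(labels) == 1:
        return f"label:{labels[0]}"
    return "(" + " OR ".join(f"label:{lbl}" for lbl in labels) + ")"


def effective_since(since: Optional[datetime], cfg_from: Optional[str]) -> Optional[datetime]:
    """Pick the later of the last-run timestamp and the configured start date."""
    if not cfg_from:
        return since
    parsed = datetime.fromisoformat(cfg_from.replace("Z", "+00:00"))
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed if since is None or parsed > since else since


last_run = datetime(2025, 3, 1, tzinfo=timezone.utc)
lower = effective_since(last_run, "2025-01-15")
query = " ".join([label_clause(["INBOX", "work"]), f"after:{lower.strftime('%Y/%m/%d')}"])
print(query)
```

Here the configured start date is older than the last run, so the last-run date wins and already-processed mail is not fetched again.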
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw_html: str) -> str:
|
||||||
|
"""Remove HTML tags and decode entities to get plain text."""
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
||||||
|
decoded = html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
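A quick check of what the tag-stripping approach does and does not handle. The function below is a copy of the three steps under a hypothetical public name, `strip_html`; the sample inputs are invented.

```python
import html
import re


def strip_html(raw_html: str) -> str:
    """Replace tags with spaces, decode entities, then collapse whitespace."""
    no_tags = re.sub(r"<[^>]+>", " ", raw_html)
    decoded = html.unescape(no_tags)
    return re.sub(r"\s+", " ", decoded).strip()


print(strip_html("<p>Budget &amp; plan</p>\n<br>due Friday"))
print(strip_html("a &lt;b&gt; c"))
```

Because entities are decoded after tags are removed, escaped markup such as `&lt;b&gt;` survives as literal text. The regex is a heuristic rather than an HTML parser, so the contents of `<script>` or `<style>` elements would be kept as text.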
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_body(payload: dict[str, Any]) -> str:
|
||||||
|
"""Recursively extract the plain-text body from a Gmail message payload.
|
||||||
|
|
||||||
|
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
|
||||||
|
Returns an empty string if no body can be extracted.
|
||||||
|
"""
|
||||||
|
mime_type: str = payload.get("mimeType", "")
|
||||||
|
body: dict = payload.get("body", {})
|
||||||
|
parts: list[dict] = payload.get("parts", [])
|
||||||
|
|
||||||
|
if mime_type == "text/plain":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if mime_type == "text/html":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return _strip_html(raw)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Multipart — prefer text/plain part, fall back to text/html.
|
||||||
|
plain_fallback = ""
|
||||||
|
for part in parts:
|
||||||
|
part_mime = part.get("mimeType", "")
|
||||||
|
if part_mime == "text/plain":
|
||||||
|
return _parse_body(part)
|
||||||
|
if part_mime == "text/html" and not plain_fallback:
|
||||||
|
plain_fallback = _parse_body(part)
|
||||||
|
if part_mime.startswith("multipart/"):
|
||||||
|
nested = _parse_body(part)
|
||||||
|
if nested:
|
||||||
|
return nested
|
||||||
|
return plain_fallback
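The `data + "=="` idiom in `_parse_body` is there because Gmail body data is URL-safe base64 that may arrive without padding. A small sketch of why appending two pad characters is safe; `decode_part` is a hypothetical name and the sample string is invented.

```python
import base64


def decode_part(data: str) -> str:
    """Decode URL-safe base64 whose trailing '=' padding may be missing."""
    # Surplus padding is ignored by the non-strict decoder, so always adding
    # "==" covers inputs that lack zero, one or two pad characters.
    return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")


unpadded = base64.urlsafe_b64encode("Hi there".encode()).decode().rstrip("=")
print(decode_part(unpadded))
print(decode_part("YWJj"))  # already a multiple of four characters
```

The `errors="replace"` argument means malformed UTF-8 produces replacement characters rather than an exception.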
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_date(raw: str) -> datetime:
|
||||||
|
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
|
||||||
|
try:
|
||||||
|
parsed = email.utils.parsedate_to_datetime(raw)
|
||||||
|
if parsed.tzinfo is None:
|
||||||
|
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||||
|
return parsed.astimezone(timezone.utc)
|
||||||
|
except Exception:
|
||||||
|
return datetime.now(timezone.utc)
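The normalisation in `_parse_date` can be seen on two sample headers. This sketch omits the fallback to the current time and uses a hypothetical name, `to_utc`; the header values are invented.

```python
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime


def to_utc(raw: str) -> datetime:
    """Parse an RFC 2822 date header and normalise it to UTC."""
    parsed = parsedate_to_datetime(raw)
    if parsed.tzinfo is None:
        # A "-0000" offset yields a naive datetime; treat it as UTC.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)


stamp = to_utc("Tue, 14 Jan 2025 09:30:00 +0200")
unknown_zone = to_utc("Tue, 14 Jan 2025 09:30:00 -0000")
print(stamp.isoformat(), unknown_zone.isoformat())
```

In the original, any parsing failure is swallowed and replaced with the current time, so a malformed header makes a message look newer than it is.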
|
||||||
|
|
||||||
|
|
||||||
|
class GmailClient:
|
||||||
|
"""Fetch email messages from a Gmail account via the Gmail REST API.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
credentials_info:
|
||||||
|
Decrypted OAuth2 credential dict. Must contain at minimum
|
||||||
|
``token`` (access token) or ``refresh_token`` + ``token_uri`` +
|
||||||
|
``client_id`` + ``client_secret``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
|
self._credentials_info = credentials_info
|
||||||
|
expiry_str: str | None = credentials_info.get("expiry")
|
||||||
|
expiry: datetime | None = None
|
||||||
|
if expiry_str:
|
||||||
|
try:
|
||||||
|
expiry = datetime.fromisoformat(
|
||||||
|
expiry_str.replace("Z", "+00:00")
|
||||||
|
).replace(tzinfo=timezone.utc)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._credentials = Credentials(
|
||||||
|
token=credentials_info.get("token"),
|
||||||
|
refresh_token=credentials_info.get("refresh_token"),
|
||||||
|
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
|
||||||
|
client_id=credentials_info.get("client_id"),
|
||||||
|
client_secret=credentials_info.get("client_secret"),
|
||||||
|
scopes=credentials_info.get("scopes"),
|
||||||
|
expiry=expiry,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Public API ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def fetch_messages(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[EmailMessage]:
|
||||||
|
"""Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.
|
||||||
|
|
||||||
|
Runs the synchronous Google API calls inside ``asyncio.to_thread()``
|
||||||
|
to avoid blocking the async event loop.
|
||||||
|
|
||||||
|
Token refresh is performed automatically when the access token has
|
||||||
|
expired. After the call, ``self.refreshed_credentials`` may be
|
||||||
|
consulted to detect whether new credentials should be persisted.
|
||||||
|
"""
|
||||||
|
query = _build_gmail_query(filter_config, since)
|
||||||
|
logger.debug("gmail: executing search query %r", query)
|
||||||
|
return await asyncio.to_thread(self._fetch_sync, query)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
"""Return updated credential dict if the access token was refreshed.
|
||||||
|
|
||||||
|
If the credentials were refreshed during ``fetch_messages()``, returns
|
||||||
|
a new dict that should be re-encrypted and written back to the DB.
|
||||||
|
Returns ``None`` if no refresh occurred.
|
||||||
|
"""
|
||||||
|
creds = self._credentials
|
||||||
|
if not creds.valid and creds.expired:
|
||||||
|
return None
|
||||||
|
# Check whether the token changed from what was stored.
|
||||||
|
if creds.token != self._credentials_info.get("token"):
|
||||||
|
result = {
|
||||||
|
"token": creds.token,
|
||||||
|
"refresh_token": creds.refresh_token,
|
||||||
|
"token_uri": creds.token_uri,
|
||||||
|
"client_id": creds.client_id,
|
||||||
|
"client_secret": creds.client_secret,
|
||||||
|
"scopes": list(creds.scopes or []),
|
||||||
|
}
|
||||||
|
if creds.expiry:
|
||||||
|
result["expiry"] = creds.expiry.isoformat()
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ── Internal sync worker ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _fetch_sync(self, query: str) -> list[EmailMessage]:
|
||||||
|
"""Synchronous worker — called inside ``asyncio.to_thread()``."""
|
||||||
|
import googleapiclient.discovery
|
||||||
|
import googleapiclient.errors
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
|
||||||
|
# Refresh token if needed before building the service.
|
||||||
|
if self._credentials.expired and self._credentials.refresh_token:
|
||||||
|
try:
|
||||||
|
self._credentials.refresh(Request())
|
||||||
|
except Exception as exc:
|
||||||
|
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
|
||||||
|
|
||||||
|
service = googleapiclient.discovery.build(
|
||||||
|
"gmail", "v1", credentials=self._credentials, cache_discovery=False
|
||||||
|
)
|
||||||
|
user_api = service.users() # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# ── List matching message IDs ──────────────────────────────────────
|
||||||
|
ids: list[str] = []
|
||||||
|
page_token: str | None = None
|
||||||
|
while len(ids) < _MAX_MESSAGES:
|
||||||
|
batch_size = min(100, _MAX_MESSAGES - len(ids))
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"userId": "me",
|
||||||
|
"maxResults": batch_size,
|
||||||
|
}
|
||||||
|
if query:
|
||||||
|
kwargs["q"] = query
|
||||||
|
if page_token:
|
||||||
|
kwargs["pageToken"] = page_token
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = user_api.messages().list(**kwargs).execute()
|
||||||
|
except googleapiclient.errors.HttpError as exc:
|
||||||
|
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
|
||||||
|
|
||||||
|
for msg in resp.get("messages", []):
|
||||||
|
ids.append(msg["id"])
|
||||||
|
|
||||||
|
page_token = resp.get("nextPageToken")
|
||||||
|
if not page_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not ids:
|
||||||
|
logger.debug("gmail: no messages matched query %r", query)
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info("gmail: fetching %d message(s)", len(ids))
|
||||||
|
|
||||||
|
# ── Fetch individual message details ──────────────────────────────
|
||||||
|
messages: list[EmailMessage] = []
|
||||||
|
for msg_id in ids:
|
||||||
|
try:
|
||||||
|
msg = user_api.messages().get(
|
||||||
|
userId="me", id=msg_id, format="full"
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
headers: dict[str, str] = {
|
||||||
|
h["name"].lower(): h["value"]
|
||||||
|
for h in msg.get("payload", {}).get("headers", [])
|
||||||
|
}
|
||||||
|
subject = headers.get("subject", "(no subject)")
|
||||||
|
sender = headers.get("from", "unknown")
|
||||||
|
date_raw = headers.get("date", "")
|
||||||
|
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
|
||||||
|
labels = msg.get("labelIds", [])
|
||||||
|
|
||||||
|
messages.append(EmailMessage(
|
||||||
|
id=msg_id,
|
||||||
|
subject=subject,
|
||||||
|
sender=sender,
|
||||||
|
body_text=body_text,
|
||||||
|
date=date,
|
||||||
|
labels=labels,
|
||||||
|
))
|
||||||
|
except googleapiclient.errors.HttpError as exc:
|
||||||
|
logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)
|
||||||
|
|
||||||
|
logger.info("gmail: returned %d message(s)", len(messages))
|
||||||
|
return messages
|
||||||
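Review note: the part-selection order used above (prefer `text/plain`, fall back to `text/html`, recurse into nested `multipart/*`) can be checked in isolation with a small stand-alone sketch. The helper names `strip_html`, `decode_part` and `extract_text` are hypothetical stand-ins rather than the module's own functions, and the payload is assumed to follow the Gmail message shape (`mimeType`, `parts`, `body.data` as unpadded base64url).

```python
import base64
import html
import re
from typing import Any


def strip_html(raw: str) -> str:
    """Drop tags, unescape entities, collapse whitespace."""
    no_tags = re.sub(r"<[^>]+>", " ", raw)
    return re.sub(r"\s+", " ", html.unescape(no_tags)).strip()


def decode_part(part: dict[str, Any]) -> str:
    """Decode one leaf part; Gmail sends base64url without '=' padding."""
    data = (part.get("body") or {}).get("data", "")
    if not data:
        return ""
    padded = data + "=" * (-len(data) % 4)  # restore the exact padding
    text = base64.urlsafe_b64decode(padded).decode("utf-8", errors="replace")
    return strip_html(text) if part.get("mimeType") == "text/html" else text


def extract_text(payload: dict[str, Any]) -> str:
    """Prefer text/plain, fall back to text/html, recurse into multipart/*."""
    parts = payload.get("parts") or []
    if not parts:
        return decode_part(payload)
    html_fallback = ""
    for part in parts:
        mime = part.get("mimeType", "")
        if mime == "text/plain":
            return decode_part(part)
        if mime == "text/html" and not html_fallback:
            html_fallback = decode_part(part)
        elif mime.startswith("multipart/"):
            nested = extract_text(part)
            if nested:
                return nested
    return html_fallback
```

A `multipart/alternative` message that lists its HTML part first still yields the plain-text body, because the HTML text is only held as a fallback until the loop finishes.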
352
app/integrations/ms_graph.py
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
"""Microsoft Graph API client for Outlook and Teams cloud agent integration.
|
||||||
|
|
||||||
|
Handles two data sources:
|
||||||
|
|
||||||
|
* **Outlook email** (``provider="outlook"``) — ``fetch_emails()`` calls
|
||||||
|
``/me/messages`` with an OData ``$filter`` built from ``filter_config``.
|
||||||
|
* **Teams messages** (``provider="teams"``) — ``fetch_messages()`` calls
|
||||||
|
``/me/chats/getAllMessages`` filtered by date.
|
||||||
|
|
||||||
|
Authentication uses MSAL ``ConfidentialClientApplication`` to acquire a token
|
||||||
|
from a stored refresh token. The ``httpx.AsyncClient`` (already a project
|
||||||
|
dependency) is used for all API calls.
|
||||||
|
|
||||||
|
Credential dict shape (Microsoft OAuth2 / MSAL):
|
||||||
|
{
|
||||||
|
"access_token": "<access_token>",
|
||||||
|
"refresh_token": "<refresh_token>",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"scope": "Mail.Read ChannelMessage.Read.All offline_access",
|
||||||
|
"expires_in": 3600
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.integrations import ChatMessage, EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
||||||
|
|
||||||
|
# Max items fetched per run.
|
||||||
|
_MAX_EMAILS = 200
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
|
||||||
|
# Max characters of body forwarded to the LLM.
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw: str) -> str:
|
||||||
|
"""Strip HTML tags and collapse whitespace."""
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw)
|
||||||
|
import html as _html
|
||||||
|
decoded = _html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _odata_datetime(dt: datetime) -> str:
|
||||||
|
"""Format a datetime as an OData datetime literal (UTC, ISO 8601)."""
|
||||||
|
utc = dt.astimezone(timezone.utc)
|
||||||
|
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_email_filter(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
"""Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.
|
||||||
|
|
||||||
|
Supported ``filter_config`` keys:
|
||||||
|
senders (list[str]): Sender email addresses.
|
||||||
|
date_range (dict): ``{from: "<ISO-8601>", to: "<ISO-8601>"}``
|
||||||
|
folders (list[str]): Folder display names (not directly filterable
|
||||||
|
via OData, so ignored here — callers iterate
|
||||||
|
folder IDs separately if needed; listed for
|
||||||
|
completeness).
|
||||||
|
|
||||||
|
The lower bound is the later of ``since`` and ``date_range.from``.
|
||||||
|
"""
|
||||||
|
clauses: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
# Senders.
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
if senders:
|
||||||
|
sender_clauses = ["from/emailAddress/address eq '" + s.replace("'", "''") + "'" for s in senders]  # OData escapes ' by doubling it
|
||||||
|
clauses.append("(" + " or ".join(sender_clauses) + ")")
|
||||||
|
|
||||||
|
# Date range.
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
|
||||||
|
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
if to_dt.tzinfo is None:
|
||||||
|
to_dt = to_dt.replace(tzinfo=timezone.utc)
|
||||||
|
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " and ".join(clauses)
|
||||||
|
|
||||||
|
|
||||||
|
class MSGraphClient:
|
||||||
|
"""Fetch emails and Teams messages via the Microsoft Graph REST API.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
credentials_info:
|
||||||
|
Decrypted MSAL credential dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
|
self._credentials_info = credentials_info
|
||||||
|
self._access_token: str = credentials_info.get("access_token", "")
|
||||||
|
self._original_access_token: str = self._access_token
|
||||||
|
self._refresh_token: str | None = credentials_info.get("refresh_token")
|
||||||
|
|
||||||
|
# ── Token management ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _auth_headers(self) -> dict[str, str]:
|
||||||
|
return {"Authorization": f"Bearer {self._access_token}"}
|
||||||
|
|
||||||
|
async def _refresh_access_token(self) -> None:
|
||||||
|
"""Use MSAL to exchange the refresh token for a fresh access token.
|
||||||
|
|
||||||
|
Updates ``self._access_token`` and ``self._credentials_info`` in-place.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: MSAL reports an auth error.
|
||||||
|
"""
|
||||||
|
import msal
|
||||||
|
|
||||||
|
app = msal.ConfidentialClientApplication(
|
||||||
|
client_id=settings.MS_CLIENT_ID,
|
||||||
|
client_credential=settings.MS_CLIENT_SECRET,
|
||||||
|
authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
|
||||||
|
)
|
||||||
|
scopes: list[str] = self._credentials_info.get("scope", "").split()
|
||||||
|
if not scopes:
|
||||||
|
scopes = ["https://graph.microsoft.com/.default"]
|
||||||
|
|
||||||
|
result = app.acquire_token_by_refresh_token(
|
||||||
|
self._refresh_token,
|
||||||
|
scopes=scopes,
|
||||||
|
)
|
||||||
|
if "access_token" not in result:
|
||||||
|
error = result.get("error_description", result.get("error", "unknown"))
|
||||||
|
raise RuntimeError(f"MS Graph token refresh failed: {error}")
|
||||||
|
|
||||||
|
self._access_token = result["access_token"]
|
||||||
|
# MSAL may issue a new refresh token.
|
||||||
|
if "refresh_token" in result:
|
||||||
|
self._refresh_token = result["refresh_token"]
|
||||||
|
self._credentials_info["refresh_token"] = result["refresh_token"]
|
||||||
|
self._credentials_info["access_token"] = self._access_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
"""Return updated credential dict if the access token was refreshed.
|
||||||
|
|
||||||
|
Returns ``None`` if no change was made.
|
||||||
|
"""
|
||||||
|
if self._access_token != self._original_access_token:
|
||||||
|
return {**self._credentials_info, "access_token": self._access_token}
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ── HTTP helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get(
|
||||||
|
self,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
url: str,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
retry_on_401: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""GET *url* with auth; refresh token on 401 and retry once."""
|
||||||
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
|
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
|
||||||
|
logger.debug("ms_graph: 401 on %s — refreshing token", url)
|
||||||
|
await self._refresh_access_token()
|
||||||
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
|
if resp.status_code == 429:
|
||||||
|
raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
# ── Public API ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def fetch_emails(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[EmailMessage]:
|
||||||
|
"""Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filter_config:
|
||||||
|
Optional dict with ``senders``, ``date_range``, ``folders`` keys.
|
||||||
|
since:
|
||||||
|
Hard lower-bound on email date (from last agent run).
|
||||||
|
"""
|
||||||
|
odata_filter = _build_email_filter(filter_config, since)
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"$top": 50,
|
||||||
|
"$select": "id,subject,from,receivedDateTime,body,bodyPreview",
|
||||||
|
"$orderby": "receivedDateTime desc",
|
||||||
|
}
|
||||||
|
if odata_filter:
|
||||||
|
params["$filter"] = odata_filter
|
||||||
|
|
||||||
|
emails: list[EmailMessage] = []
|
||||||
|
url = f"{_GRAPH_BASE}/me/messages"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
while url and len(emails) < _MAX_EMAILS:
|
||||||
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
|
for item in data.get("value", []):
|
||||||
|
emails.append(self._parse_email(item))
|
||||||
|
if len(emails) >= _MAX_EMAILS:
|
||||||
|
break
|
||||||
|
url = data.get("@odata.nextLink", "")
|
||||||
|
params = {} # nextLink already contains encoded params.
|
||||||
|
|
||||||
|
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
|
||||||
|
return emails
|
||||||
|
|
||||||
|
async def fetch_messages(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[ChatMessage]:
|
||||||
|
"""Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*.
|
||||||
|
|
||||||
|
Fetches from ``/me/chats/getAllMessages`` (personal + group chats).
|
||||||
|
``filter_config.channels`` is applied after fetching, as a case-insensitive
substring match on the message's ``channelIdentity.channelId``; messages
without a channel identity are kept. ``getAllMessages`` has no OData
filter for channels.
|
||||||
|
"""
|
||||||
|
cfg = filter_config or {}
|
||||||
|
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
|
||||||
|
params: dict[str, Any] = {"$top": 50}
|
||||||
|
if since:
|
||||||
|
params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"
|
||||||
|
|
||||||
|
messages: list[ChatMessage] = []
|
||||||
|
url = f"{_GRAPH_BASE}/me/chats/getAllMessages"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
while url and len(messages) < _MAX_MESSAGES:
|
||||||
|
try:
|
||||||
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
# getAllMessages requires specific licensing; degrade gracefully.
|
||||||
|
if exc.response.status_code in (403, 404):
|
||||||
|
logger.warning(
|
||||||
|
"ms_graph: /me/chats/getAllMessages not available (%d) — "
|
||||||
|
"check Teams license or permissions",
|
||||||
|
exc.response.status_code,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
raise
|
||||||
|
|
||||||
|
for item in data.get("value", []):
|
||||||
|
msg = self._parse_teams_message(item)
|
||||||
|
if channel_filter and msg.channel:
|
||||||
|
if not any(c in msg.channel.lower() for c in channel_filter):
|
||||||
|
continue
|
||||||
|
messages.append(msg)
|
||||||
|
if len(messages) >= _MAX_MESSAGES:
|
||||||
|
break
|
||||||
|
url = data.get("@odata.nextLink", "")
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
|
||||||
|
return messages
|
||||||
|
|
||||||
|
# ── Parsers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_email(item: dict[str, Any]) -> EmailMessage:
|
||||||
|
subject: str = item.get("subject", "(no subject)") or "(no subject)"
|
||||||
|
sender_block = item.get("from", {}) or {}
|
||||||
|
sender_addr = (
|
||||||
|
(sender_block.get("emailAddress") or {}).get("address", "unknown")
|
||||||
|
)
|
||||||
|
date_str: str = item.get("receivedDateTime", "")
|
||||||
|
try:
|
||||||
|
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except Exception:
|
||||||
|
date = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_block = item.get("body", {}) or {}
|
||||||
|
content_type: str = body_block.get("contentType", "text")
|
||||||
|
raw_body: str = body_block.get("content", "")
|
||||||
|
if content_type == "html":
|
||||||
|
body_text = _strip_html(raw_body)
|
||||||
|
else:
|
||||||
|
body_text = raw_body or item.get("bodyPreview", "")
|
||||||
|
body_text = body_text[:_BODY_TRUNCATE]
|
||||||
|
|
||||||
|
return EmailMessage(
|
||||||
|
id=item.get("id", ""),
|
||||||
|
subject=subject,
|
||||||
|
sender=sender_addr,
|
||||||
|
body_text=body_text,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
|
||||||
|
msg_id: str = item.get("id", "")
|
||||||
|
sender_block = (item.get("from") or {}).get("user") or {}
|
||||||
|
sender: str = sender_block.get("displayName", "unknown")
|
||||||
|
channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
|
||||||
|
|
||||||
|
date_str: str = item.get("createdDateTime", "")
|
||||||
|
try:
|
||||||
|
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except Exception:
|
||||||
|
date = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_block = item.get("body", {}) or {}
|
||||||
|
content_type: str = body_block.get("contentType", "text")
|
||||||
|
raw_content: str = body_block.get("content", "")
|
||||||
|
content = _strip_html(raw_content) if content_type == "html" else raw_content
|
||||||
|
content = content[:_BODY_TRUNCATE]
|
||||||
|
|
||||||
|
return ChatMessage(
|
||||||
|
id=msg_id,
|
||||||
|
content=content,
|
||||||
|
sender=sender,
|
||||||
|
channel=channel,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
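Review note: the `$filter` construction in this file can be exercised without any network access. The sketch below is a simplified stand-in, not the module's `_build_email_filter`: the names `odata_quote`, `odata_datetime` and `build_filter` are hypothetical, it takes already-parsed datetimes, and it assumes timezone-aware inputs. It shows two details worth keeping in mind: single quotes inside an OData string literal are escaped by doubling them, and the lower bound is the later of the two candidate dates.

```python
from datetime import datetime, timezone
from typing import Optional


def odata_datetime(dt: datetime) -> str:
    """Format an aware datetime as a UTC OData literal."""
    return dt.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")


def odata_quote(value: str) -> str:
    """Wrap a value in single quotes, doubling any embedded single quote."""
    return "'" + value.replace("'", "''") + "'"


def build_filter(
    senders: list[str],
    since: Optional[datetime],
    cfg_from: Optional[datetime],
) -> str:
    """Combine sender and date clauses with ``and``; empty string if none."""
    clauses: list[str] = []
    if senders:
        alternatives = [
            f"from/emailAddress/address eq {odata_quote(s)}" for s in senders
        ]
        clauses.append("(" + " or ".join(alternatives) + ")")
    # Use the later of the two lower bounds, ignoring whichever is missing.
    bounds = [b for b in (since, cfg_from) if b is not None]
    if bounds:
        clauses.append(f"receivedDateTime ge {odata_datetime(max(bounds))}")
    return " and ".join(clauses)
```

Keeping the quoting in one helper makes it harder for a sender address containing an apostrophe to break out of the string literal.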
16
app/main.py
@@ -1,8 +1,16 @@
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
import logging
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||||
|
|
||||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
@@ -10,10 +18,7 @@ from app.config.settings import settings
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup: initialise DB connection pool and agent registry
|
# Startup: initialise DB connection pool
|
||||||
from app.core.agent_registry import registry # noqa: F401 — triggers module load
|
|
||||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown: dispose SQLAlchemy connection pool
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
@@ -43,11 +48,10 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors
|
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
app.include_router(plans.router, prefix="/api/v1")
|
|
||||||
app.include_router(storage.router, prefix="/api/v1")
|
app.include_router(storage.router, prefix="/api/v1")
|
||||||
app.include_router(vectors.router, prefix="/api/v1")
|
app.include_router(vectors.router, prefix="/api/v1")
|
||||||
app.include_router(backup.router, prefix="/api/v1")
|
app.include_router(backup.router, prefix="/api/v1")
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
|
|||||||
"write:projects",
|
"write:projects",
|
||||||
"read:notes",
|
"read:notes",
|
||||||
"write:notes",
|
"write:notes",
|
||||||
"read:checkpoints",
|
"read:timelines",
|
||||||
"write:checkpoints",
|
"write:timelines",
|
||||||
"read:calendar",
|
"read:calendar",
|
||||||
"write:calendar",
|
"write:calendar",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ Table inventory:
|
|||||||
plugin_installations — per-user install records
|
plugin_installations — per-user install records
|
||||||
plugin_reviews — admin review decisions
|
plugin_reviews — admin review decisions
|
||||||
revenue_events — Stripe Connect 70/30 split ledger
|
revenue_events — Stripe Connect 70/30 split ledger
|
||||||
|
memory_core — per-user persistent key/value preferences (encrypted)
|
||||||
|
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||||
|
memory_episodic — per-user session summaries (encrypted)
|
||||||
|
memory_proactive — per-user behavioral patterns (encrypted)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -71,9 +75,14 @@ class User(Base):
|
|||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
)
|
)
|
||||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
|
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||||
|
# Used to encrypt/decrypt all memory rows for this user.
|
||||||
|
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
)
|
)
|
||||||
@@ -375,3 +384,93 @@ class AgentRunLog(Base):
|
|||||||
foreign_keys="AgentRunLog.agent_id",
|
foreign_keys="AgentRunLog.agent_id",
|
||||||
overlaps="run_logs,local_agent",
|
overlaps="run_logs,local_agent",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Memory models ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryCore(Base):
|
||||||
|
"""Per-user persistent key/value preferences, encrypted at rest.
|
||||||
|
|
||||||
|
Examples: preferred_language, timezone, work_style.
|
||||||
|
Decrypted in-memory only using User.encryption_key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_core"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryAssociative(Base):
|
||||||
|
"""Per-user semantic memory: encrypted content + pgvector embedding for similarity search.
|
||||||
|
|
||||||
|
Production: ``embedding`` column is ``vector(1536)`` via pgvector.
|
||||||
|
Tests (SQLite): stored as JSON list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_associative"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
|
||||||
|
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||||
|
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryEpisodic(Base):
|
||||||
|
"""Per-user session summaries, encrypted at rest.
|
||||||
|
|
||||||
|
One row per session interaction; used to recall recent conversations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_episodic"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
session_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryProactive(Base):
|
||||||
|
"""Per-user inferred behavioral patterns, encrypted at rest.
|
||||||
|
|
||||||
|
Confidence in [0.0, 1.0]; only patterns above threshold are injected.
|
||||||
|
Source: 'inferred' (from episodes) or 'explicit' (user-stated).
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_proactive"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
|
||||||
|
source: Mapped[str] = mapped_column(String(50), nullable=False, default="inferred")
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
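Review note: since `MemoryAssociative.embedding` is a plain JSON list under SQLite, tests cannot lean on pgvector for similarity search. One possible test-side fallback, offered only as a sketch under that assumption, is to rank rows by cosine similarity in pure Python. The functions `cosine_similarity` and `top_k` and the `(row_id, embedding)` tuple shape are hypothetical and do not appear in this diff.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(
    query: list[float],
    rows: list[tuple[str, list[float]]],
    k: int = 3,
) -> list[str]:
    """Return the ids of the k rows most similar to the query, best first."""
    scored = [(cosine_similarity(query, emb), row_id) for row_id, emb in rows]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [row_id for _, row_id in scored[:k]]
```

This is linear in the number of rows, which is fine for fixtures but is not a substitute for an indexed vector search in production.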
|
|||||||
@@ -27,6 +27,8 @@ class AuthTokens(BaseModel):
|
|||||||
class UserProfile(BaseModel):
|
class UserProfile(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
email: str
|
email: str
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
tier: BillingTier
|
tier: BillingTier
|
||||||
|
|
||||||
|
|
||||||
@@ -39,41 +41,13 @@ class ChatContext(BaseModel):
|
|||||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class PlanAction(BaseModel):
|
|
||||||
type: Literal[
|
|
||||||
"create_record",
|
|
||||||
"update_record",
|
|
||||||
"delete_record",
|
|
||||||
"index_document",
|
|
||||||
"send_notification",
|
|
||||||
]
|
|
||||||
table: str | None = None
|
|
||||||
data: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
class ChatRequest(BaseModel):
|
||||||
message: str
|
message: str
|
||||||
context: ChatContext = Field(default_factory=ChatContext)
|
context: ChatContext = Field(default_factory=ChatContext)
|
||||||
execution_mode: Literal["direct", "plan"] = "direct"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatResponse(BaseModel):
|
class ChatResponse(BaseModel):
|
||||||
response: str
|
response: str
|
||||||
actions: list[PlanAction] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Execution Plans ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class PlanStep(BaseModel):
|
|
||||||
action: str
|
|
||||||
prompt_template: str | None = None
|
|
||||||
variables: dict[str, Any] | None = None
|
|
||||||
data_from_step: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlan(BaseModel):
|
|
||||||
agent: str
|
|
||||||
steps: list[PlanStep] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Backup ───────────────────────────────────────────────────────────
|
# ── Backup ───────────────────────────────────────────────────────────
|
||||||
@@ -161,6 +135,7 @@ class PluginInstallRequest(BaseModel):
|
|||||||
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
||||||
|
|
||||||
class WsFrameType(str, Enum):
|
class WsFrameType(str, Enum):
|
||||||
|
# ── v2 frame types (kept for backward compat) ──────────────────────
|
||||||
chat_request = "chat_request"
|
chat_request = "chat_request"
|
||||||
text_chunk = "text_chunk"
|
text_chunk = "text_chunk"
|
||||||
tool_call = "tool_call"
|
tool_call = "tool_call"
|
||||||
@@ -171,6 +146,16 @@ class WsFrameType(str, Enum):
|
|||||||
agent_data = "agent_data"
|
agent_data = "agent_data"
|
||||||
agent_complete = "agent_complete"
|
agent_complete = "agent_complete"
|
||||||
device_hello = "device_hello"
|
device_hello = "device_hello"
|
||||||
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
|
home_request = "home_request"
|
||||||
|
floating_request = "floating_request"
|
||||||
|
stream_start = "stream_start"
|
||||||
|
stream_text = "stream_text"
|
||||||
|
stream_end = "stream_end"
|
||||||
|
floating_domain = "floating_domain"
|
||||||
|
data_request = "data_request"
|
||||||
|
data_response = "data_response"
|
||||||
|
mutation = "mutation"
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -249,6 +234,62 @@ class WsAgentComplete(BaseModel):
|
|||||||
errors: list[str] = Field(default_factory=list)
|
errors: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────

class WsFloatingScope(BaseModel):
    """Scope for a floating request — narrows the agent to a specific entity."""

    type: Literal["task", "project", "note", "timeline"]
    id: str | None = None


class WsHomeRequest(BaseModel):
    """Client → Server: Home chat message."""

    type: Literal[WsFrameType.home_request] = WsFrameType.home_request
    message: str
    conversation_history: list[dict[str, Any]] = Field(default_factory=list)


class WsFloatingRequest(BaseModel):
    """Client → Server: Floating chat message scoped to an entity."""

    type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
    message: str
    scope: WsFloatingScope


class WsStreamStart(BaseModel):
    """Server → Client: signals start of a streaming response."""

    type: Literal[WsFrameType.stream_start] = WsFrameType.stream_start
    request_id: str


class WsStreamText(BaseModel):
    """Server → Client: streamed text token."""

    type: Literal[WsFrameType.stream_text] = WsFrameType.stream_text
    request_id: str
    chunk: str


class WsStreamEnd(BaseModel):
    """Server → Client: signals end of a streaming response."""

    type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
    request_id: str
    mutations: list[dict[str, Any]] = Field(default_factory=list)


class WsFloatingDomain(BaseModel):
    """Server → Client: domain determined for a floating request."""

    type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
    request_id: str
    domain: Literal["tasks", "timelines", "notes", "projects"]
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
class AgentCatalogItem(BaseModel):
|
class AgentCatalogItem(BaseModel):
|
||||||
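Review note: the v3 models in this file define `stream_start`, `stream_text`, and `stream_end` frames keyed by `request_id`. The sketch below shows one way such a stream could be emitted and reassembled. It uses plain dictionaries and JSON instead of the Pydantic models so it runs with the standard library alone. The JSON wire encoding, the frame ordering, and the helper names `stream_frames` and `reassemble` are assumptions, not taken from the diff.

```python
from __future__ import annotations

import json
from typing import Any, Iterable, Iterator


def stream_frames(
    request_id: str,
    chunks: Iterable[str],
    mutations: list[dict[str, Any]] | None = None,
) -> Iterator[str]:
    """Yield JSON-encoded stream_start, stream_text and stream_end frames in order."""
    yield json.dumps({"type": "stream_start", "request_id": request_id})
    for chunk in chunks:
        yield json.dumps(
            {"type": "stream_text", "request_id": request_id, "chunk": chunk}
        )
    yield json.dumps(
        {"type": "stream_end", "request_id": request_id, "mutations": mutations or []}
    )


def reassemble(frames: Iterable[str]) -> tuple[str, list[dict[str, Any]]]:
    """Client side: join stream_text chunks until a stream_end frame arrives."""
    parts: list[str] = []
    for raw in frames:
        frame = json.loads(raw)
        if frame["type"] == "stream_text":
            parts.append(frame["chunk"])
        elif frame["type"] == "stream_end":
            return "".join(parts), frame["mutations"]
    raise ValueError("stream finished without a stream_end frame")
```

Carrying `mutations` on the final frame, as `WsStreamEnd` does, lets a client apply data changes once the text is complete. How the real client consumes that list is not shown in this hunk.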
|
|||||||
@@ -8,13 +8,16 @@ services:
|
|||||||
required: false
|
required: false
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
||||||
|
volumes:
|
||||||
|
- copilot_tokens:/root/.config/litellm/github_copilot
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
db:
|
db:
|
||||||
image: postgres:16-alpine
|
image: pgvector/pgvector:pg16
|
||||||
environment:
|
environment:
|
||||||
POSTGRES_USER: postgres
|
POSTGRES_USER: postgres
|
||||||
POSTGRES_PASSWORD: postgres
|
POSTGRES_PASSWORD: postgres
|
||||||
@@ -66,3 +69,4 @@ volumes:
|
|||||||
postgres_data:
|
postgres_data:
|
||||||
minio_data:
|
minio_data:
|
||||||
qdrant_data:
|
qdrant_data:
|
||||||
|
copilot_tokens:
|
||||||
|
|||||||
56
logging.conf
Normal file
@@ -0,0 +1,56 @@
|
|||||||
[loggers]
keys=root,uvicorn,uvicorn.error,uvicorn.access,sqlalchemy,watchfiles

[handlers]
keys=console,file

[formatters]
keys=default

[logger_root]
level=INFO
handlers=console,file

[logger_uvicorn]
level=INFO
handlers=
qualname=uvicorn
propagate=1

[logger_uvicorn.error]
level=INFO
handlers=
qualname=uvicorn.error
propagate=1

[logger_uvicorn.access]
level=INFO
handlers=
qualname=uvicorn.access
propagate=1

[logger_sqlalchemy]
level=WARNING
handlers=
qualname=sqlalchemy
propagate=1

[logger_watchfiles]
level=WARNING
handlers=
qualname=watchfiles
propagate=1

[handler_console]
class=StreamHandler
formatter=default
args=(sys.stderr,)

[handler_file]
class=logging.handlers.RotatingFileHandler
formatter=default
args=('logs/app.log', 'a', 10485760, 5, 'utf-8')

[formatter_default]
format=%(asctime)s %(levelname)s %(name)s: %(message)s
datefmt=%Y-%m-%d %H:%M:%S
||||||
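Review note: `handler_file` writes to `logs/app.log`, and `RotatingFileHandler` does not create missing parent directories. Because `logs/` is also git-ignored in this comparison, a fresh checkout may not have that directory. The sketch below shows one way the config could be loaded safely at startup. The function name `setup_logging` and its parameters are hypothetical, and the diff does not show how the application actually loads `logging.conf`.

```python
import logging
import logging.config
from pathlib import Path


def setup_logging(conf_path: str = "logging.conf", log_dir: str = "logs") -> None:
    """Create the log directory, then apply the INI-style logging config."""
    # The file handler opens its target when the config is applied,
    # so the directory has to exist before fileConfig runs.
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    # Keep loggers created at import time (for example by libraries) enabled.
    logging.config.fileConfig(conf_path, disable_existing_loggers=False)
```

Since the path in `args` is relative, the working directory at startup decides where `logs/app.log` ends up.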
@@ -3,6 +3,9 @@ uvicorn[standard]>=0.34.0
|
|||||||
gunicorn>=22.0.0
|
gunicorn>=22.0.0
|
||||||
langchain>=0.3.0
|
langchain>=0.3.0
|
||||||
langchain-openai>=0.3.0
|
langchain-openai>=0.3.0
|
||||||
|
langchain-litellm>=0.1.0
|
||||||
|
langgraph>=0.3.0
|
||||||
|
deepagents>=0.4.10
|
||||||
litellm>=1.50.0
|
litellm>=1.50.0
|
||||||
pydantic>=2.10.0
|
pydantic>=2.10.0
|
||||||
pydantic-settings>=2.7.0
|
pydantic-settings>=2.7.0
|
||||||
@@ -25,4 +28,10 @@ moto[s3]>=5.0.0
|
|||||||
pinecone>=5.0.0
|
pinecone>=5.0.0
|
||||||
qdrant-client>=1.7.0
|
qdrant-client>=1.7.0
|
||||||
croniter>=3.0.0
|
croniter>=3.0.0
|
||||||
|
google-api-python-client>=2.130.0
|
||||||
|
google-auth>=2.29.0
|
||||||
|
google-auth-oauthlib>=1.2.0
|
||||||
|
google-auth-httplib2>=0.2.0
|
||||||
|
msal>=1.28.0
|
||||||
|
cryptography>=42.0.0
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
|||||||
@@ -129,12 +129,12 @@ _SEED_PLUGINS = [
|
|||||||
Plugin(
|
Plugin(
|
||||||
id="plugin-slack-notify",
|
id="plugin-slack-notify",
|
||||||
name="Slack Notifier",
|
name="Slack Notifier",
|
||||||
description="Post task and checkpoint updates to Slack channels.",
|
description="Post task and timeline updates to Slack channels.",
|
||||||
version="1.2.0",
|
version="1.2.0",
|
||||||
author_name="Adiuva",
|
author_name="Adiuva",
|
||||||
category="communication",
|
category="communication",
|
||||||
price_cents=499,
|
price_cents=499,
|
||||||
permissions=json.dumps(["read:tasks", "read:checkpoints"]),
|
permissions=json.dumps(["read:tasks", "read:timelines"]),
|
||||||
status="approved",
|
status="approved",
|
||||||
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
|
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
install_count=0,
|
install_count=0,
|
||||||
|
|||||||
@@ -1,214 +0,0 @@
|
|||||||
"""Unit tests for the agent registry, base classes, and tool loop."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _StubAgent(ChatAgent):
|
|
||||||
"""Minimal concrete agent for testing."""
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "stub"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "A stub agent for tests"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return f"echo: {query}"
|
|
||||||
|
|
||||||
|
|
||||||
class _AnotherAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "another"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Another stub"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return "another"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _fresh_registry():
|
|
||||||
"""Reset the singleton between tests."""
|
|
||||||
AgentRegistry._instance = None
|
|
||||||
yield
|
|
||||||
AgentRegistry._instance = None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def reg() -> AgentRegistry:
|
|
||||||
return AgentRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tests ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestRegisterAndGet:
|
|
||||||
def test_register_decorator(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
agent = reg.get("stub")
|
|
||||||
assert isinstance(agent, _StubAgent)
|
|
||||||
|
|
||||||
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
|
|
||||||
with pytest.raises(KeyError, match="not found"):
|
|
||||||
reg.get("nonexistent")
|
|
||||||
|
|
||||||
def test_register_multiple(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
reg.register(_AnotherAgent)
|
|
||||||
assert reg.get("stub").get_name() == "stub"
|
|
||||||
assert reg.get("another").get_name() == "another"
|
|
||||||
|
|
||||||
|
|
||||||
class TestListAgents:
|
|
||||||
def test_empty(self, reg: AgentRegistry) -> None:
|
|
||||||
assert reg.list_agents() == []
|
|
||||||
|
|
||||||
def test_list_after_register(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
agents = reg.list_agents()
|
|
||||||
assert len(agents) == 1
|
|
||||||
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
|
|
||||||
|
|
||||||
def test_list_multiple(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
reg.register(_AnotherAgent)
|
|
||||||
names = {a["name"] for a in reg.list_agents()}
|
|
||||||
assert names == {"stub", "another"}
|
|
||||||
|
|
||||||
|
|
||||||
class TestCallAgent:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_agent(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
result = await reg.call_agent("stub", "hello", {})
|
|
||||||
assert result == "echo: hello"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
await reg.call_agent("nope", "hi", {})
|
|
||||||
|
|
||||||
|
|
||||||
class TestSingleton:
|
|
||||||
def test_singleton_identity(self) -> None:
|
|
||||||
a = AgentRegistry()
|
|
||||||
b = AgentRegistry()
|
|
||||||
assert a is b
|
|
||||||
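Review note: the removed tests above describe the registry only through its behavior: a singleton reset via `_instance`, `register`, `get` raising a `KeyError` whose message contains "not found", `list_agents`, and an async `call_agent`. The following is a reconstruction that would satisfy those assertions, not the deleted `AgentRegistry` itself. The class name `AgentRegistrySketch` and the `_agents` attribute are hypothetical.

```python
from __future__ import annotations

from typing import Any


class AgentRegistrySketch:
    """Singleton registry keyed by each agent's get_name()."""

    _instance: "AgentRegistrySketch | None" = None

    def __new__(cls) -> "AgentRegistrySketch":
        # Create the shared instance on first use; tests can reset _instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._agents = {}
        return cls._instance

    def register(self, agent_cls: type) -> type:
        """Instantiate and store the agent; returns the class so it works as a decorator."""
        agent = agent_cls()
        self._agents[agent.get_name()] = agent
        return agent_cls

    def get(self, name: str) -> Any:
        try:
            return self._agents[name]
        except KeyError:
            raise KeyError(f"Agent {name!r} not found") from None

    def list_agents(self) -> list[dict[str, str]]:
        return [
            {"name": a.get_name(), "description": a.get_description()}
            for a in self._agents.values()
        ]

    async def call_agent(self, name: str, query: str, context: dict[str, Any]) -> str:
        """Look up the agent and delegate to its async handle()."""
        return await self.get(name).handle(query, context)
```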
|
|
||||||
|
|
||||||
class TestToolLoop:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_tool_calls(self) -> None:
|
|
||||||
"""When the LLM responds without tool calls, return content directly."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
ai_msg = MagicMock()
|
|
||||||
ai_msg.content = "final answer"
|
|
||||||
ai_msg.tool_calls = []
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=ai_msg)
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [])
|
|
||||||
assert result == "final answer"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_call_then_answer(self) -> None:
|
|
||||||
"""LLM requests one tool call, gets result, then answers."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
# First response: tool call
|
|
||||||
tool_call_msg = MagicMock()
|
|
||||||
tool_call_msg.content = ""
|
|
||||||
tool_call_msg.tool_calls = [
|
|
||||||
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
|
|
||||||
]
|
|
||||||
|
|
||||||
# Second response: final answer
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = "done"
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm)
|
|
||||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
|
||||||
|
|
||||||
# Mock tool
|
|
||||||
tool = AsyncMock()
|
|
||||||
tool.name = "my_tool"
|
|
||||||
tool.ainvoke = AsyncMock(return_value="tool_result")
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [tool])
|
|
||||||
assert result == "done"
|
|
||||||
tool.ainvoke.assert_called_once_with({"x": 1})
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unknown_tool_handled(self) -> None:
|
|
||||||
"""Unknown tool names produce an error message instead of crashing."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
tool_call_msg = MagicMock()
|
|
||||||
tool_call_msg.content = ""
|
|
||||||
tool_call_msg.tool_calls = [
|
|
||||||
{"id": "call_1", "name": "missing", "args": {}}
|
|
||||||
]
|
|
||||||
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = "recovered"
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm)
|
|
||||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [])
|
|
||||||
assert result == "recovered"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_max_iter_reached(self) -> None:
|
|
||||||
"""When max iterations are exhausted, a final no-tools call is made."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
# Every response requests a tool call
|
|
||||||
loop_msg = MagicMock()
|
|
||||||
loop_msg.content = ""
|
|
||||||
loop_msg.tool_calls = [
|
|
||||||
{"id": "call_x", "name": "t", "args": {}}
|
|
||||||
]
|
|
||||||
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = "gave up"
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
tool = AsyncMock()
|
|
||||||
tool.name = "t"
|
|
||||||
tool.ainvoke = AsyncMock(return_value="ok")
|
|
||||||
|
|
||||||
llm_with_tools = AsyncMock()
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
|
|
||||||
assert result == "gave up"
|
|
||||||
assert llm_with_tools.ainvoke.call_count == 2
|
|
||||||
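Review note: the removed `TestToolLoop` cases pin down the loop's contract. The model is bound to the tools, each requested tool is invoked with its `args`, an unknown tool name yields an error message instead of an exception, and once `max_iter` rounds are used up a final call goes to the unbound model. The sketch below is inferred from those assertions only. The function name `tool_loop`, the default `max_iter`, and the dictionary shape used for tool results are assumptions, and the deleted `_tool_loop` implementation may have differed.

```python
from __future__ import annotations

from typing import Any


async def tool_loop(
    llm: Any, messages: list[Any], tools: list[Any], max_iter: int = 5
) -> str:
    """Alternate model and tool calls until the model answers without tool calls."""
    by_name = {tool.name: tool for tool in tools}
    bound = llm.bind_tools(tools)
    history = list(messages)

    for _ in range(max_iter):
        reply = await bound.ainvoke(history)
        if not reply.tool_calls:
            return reply.content
        history.append(reply)
        for call in reply.tool_calls:
            tool = by_name.get(call["name"])
            if tool is None:
                # Report the problem to the model rather than raising.
                result = f"Error: unknown tool {call['name']!r}"
            else:
                result = await tool.ainvoke(call["args"])
            history.append(
                {"role": "tool", "tool_call_id": call["id"], "content": str(result)}
            )

    # Iteration budget exhausted: ask the model without tools for a final answer.
    final = await llm.ainvoke(history)
    return final.content
```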
@@ -455,21 +455,232 @@ async def test_run_local_agent_llm_extraction_error():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_cloud_agent_stub_returns_error():
|
async def test_run_cloud_agent_device_offline():
|
||||||
"""Cloud agent stub immediately marks run as error with informative message."""
|
"""Cloud agent aborts immediately when no device is connected."""
|
||||||
config = _make_cloud_config()
|
config = _make_cloud_config()
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = DeviceConnectionManager() # empty — no devices registered
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("device" in e.lower() or "connected" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_no_oauth_token():
|
||||||
|
"""Cloud agent errors when no OAuth token is stored."""
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.oauth_token_encrypted = None
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
mgr = _make_manager()
|
mgr = _make_manager()
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
mock_finalize.assert_called_once()
|
_, kwargs = mock_finalize.call_args
|
||||||
_args, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
assert kwargs["status"] == "error"
|
||||||
assert len(kwargs["errors"]) == 1
|
assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
|
||||||
assert "gmail" in kwargs["errors"][0].lower()
|
|
||||||
assert "3.6" in kwargs["errors"][0]
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_token_decrypt_failure():
|
||||||
|
"""Cloud agent errors gracefully when the stored token cannot be decrypted."""
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet as _Fernet
|
||||||
|
valid_key = _Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||||
|
patch("app.integrations.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_ENCRYPTION_KEY = valid_key
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("decrypt" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_happy_path_gmail():
|
||||||
|
"""Cloud agent happy path: Gmail fetch → LLM extraction → inserts → success."""
|
||||||
|
from app.integrations import EmailMessage, encrypt_token
|
||||||
|
from cryptography.fernet import Fernet as _Fernet
|
||||||
|
|
||||||
|
fernet_key = _Fernet.generate_key().decode()
|
||||||
|
credentials = {
|
||||||
|
"token": "access_abc",
|
||||||
|
"refresh_token": "refresh_xyz",
|
||||||
|
"token_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"client_id": "cid",
|
||||||
|
"client_secret": "csec",
|
||||||
|
}
|
||||||
|
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.provider = "gmail"
|
||||||
|
config.prompt_template = "Extract tasks from this email."
|
||||||
|
config.data_types = ["tasks"]
|
||||||
|
|
||||||
|
with patch("app.integrations.settings") as ms:
|
||||||
|
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||||
|
config.oauth_token_encrypted = encrypt_token(credentials)
|
||||||
|
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
sample_email = EmailMessage(
|
||||||
|
id="msg001",
|
||||||
|
subject="Action required",
|
||||||
|
sender="boss@company.com",
|
||||||
|
body_text="Please fix the bug by Friday.",
|
||||||
|
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]
|
||||||
|
|
||||||
|
with patch("app.integrations.settings") as mock_int_settings, \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||||
|
patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \
|
||||||
|
patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \
|
||||||
|
patch("app.core.agent_runner.async_session"):
|
||||||
|
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||||
|
|
||||||
|
mock_gmail = AsyncMock()
|
||||||
|
mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email])
|
||||||
|
mock_gmail.refreshed_credentials = None
|
||||||
|
|
||||||
|
with patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||||
|
patch("app.integrations.get_provider", return_value=mock_gmail):
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_extract.assert_called_once()
|
||||||
|
mock_insert.assert_called_once()
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "success"
|
||||||
|
assert kwargs["items_processed"] == 1
|
||||||
|
assert kwargs["items_created"] == 1
|
||||||
|
assert kwargs["config_type"] == "cloud"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_provider_fetch_error():
|
||||||
|
"""Cloud agent records error status when provider fetch raises RuntimeError."""
|
||||||
|
credentials = {"token": "abc"}
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.oauth_token_encrypted = "some_encrypted_value" # non-empty so decrypt step is reached
|
||||||
|
config.prompt_template = "Extract tasks."
|
||||||
|
config.data_types = ["tasks"]
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
mock_provider = AsyncMock()
|
||||||
|
mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
|
||||||
|
mock_provider.refreshed_credentials = None
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||||
|
patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||||
|
patch("app.integrations.get_provider", return_value=mock_provider), \
|
||||||
|
patch("app.core.agent_runner.async_session"):
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_refreshed_token_persisted():
|
||||||
|
"""When the provider refreshes its token, the new ciphertext is written to DB."""
|
||||||
|
from app.integrations import EmailMessage, encrypt_token
|
||||||
|
from cryptography.fernet import Fernet as _Fernet
|
||||||
|
|
||||||
|
fernet_key = _Fernet.generate_key().decode()
|
||||||
|
credentials = {"token": "old_token", "refresh_token": "rt_old"}
|
||||||
|
fresh_credentials = {"token": "new_token", "refresh_token": "rt_new"}
|
||||||
|
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.prompt_template = "Extract tasks."
|
||||||
|
config.data_types = ["tasks"]
|
||||||
|
|
||||||
|
with patch("app.integrations.settings") as ms:
|
||||||
|
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||||
|
config.oauth_token_encrypted = encrypt_token(credentials)
|
||||||
|
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
mock_provider = AsyncMock()
|
||||||
|
mock_provider.fetch_messages = AsyncMock(return_value=[])
|
||||||
|
mock_provider.refreshed_credentials = fresh_credentials # token was refreshed
|
||||||
|
|
||||||
|
# Track DB writes via mock async_session.
|
||||||
|
mock_cfg_row = MagicMock()
|
||||||
|
mock_cfg_row.oauth_token_encrypted = None
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
||||||
|
mock_db.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_db.scalar_one_or_none = AsyncMock(return_value=mock_cfg_row)
|
||||||
|
cfg_result = MagicMock()
|
||||||
|
cfg_result.scalar_one_or_none.return_value = mock_cfg_row
|
||||||
|
mock_db.execute = AsyncMock(return_value=cfg_result)
|
||||||
|
mock_db.commit = AsyncMock()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \
|
||||||
|
patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||||
|
patch("app.integrations.get_provider", return_value=mock_provider), \
|
||||||
|
patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \
|
||||||
|
patch("app.core.agent_runner.async_session", return_value=mock_db), \
|
||||||
|
patch("app.integrations.settings") as mock_int_settings:
|
||||||
|
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
# The new encrypted token should have been written to the config row.
|
||||||
|
mock_encrypt.assert_called_once_with(fresh_credentials)
|
||||||
|
assert mock_cfg_row.oauth_token_encrypted == "new_encrypted"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_finalize_run_updates_cloud_config_last_run_at():
|
||||||
|
"""_finalize_run with config_type='cloud' updates CloudAgentConfig.last_run_at."""
|
||||||
|
from app.core.agent_runner import _finalize_run
|
||||||
|
|
||||||
|
run_log = _make_run_log(str(uuid.uuid4()), agent_type="cloud")
|
||||||
|
run_log.id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
mock_cfg = MagicMock()
|
||||||
|
mock_cfg.last_run_at = None
|
||||||
|
|
||||||
|
cfg_result = MagicMock()
|
||||||
|
cfg_result.scalar_one_or_none.return_value = mock_cfg
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
||||||
|
mock_db.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_db.merge = AsyncMock(return_value=run_log)
|
||||||
|
mock_db.execute = AsyncMock(return_value=cfg_result)
|
||||||
|
mock_db.commit = AsyncMock()
|
||||||
|
|
||||||
|
config_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session", return_value=mock_db):
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="success",
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config_id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
|
||||||
|
# CloudAgentConfig.last_run_at should have been set.
|
||||||
|
assert mock_cfg.last_run_at is not None
|
||||||
|
mock_db.commit.assert_called()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -1,620 +0,0 @@
|
|||||||
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
|
||||||
from app.agents.checkpoint_agent import CheckpointAgent
|
|
||||||
from app.agents.note_agent import NoteAgent
|
|
||||||
from app.agents.project_agent import ProjectAgent
|
|
||||||
from app.agents.task_agent import TaskAgent
|
|
||||||
from app.core.agent_registry import registry
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_llm(response_text: str) -> MagicMock:
|
|
||||||
"""Return a mock LLM that responds with *response_text* (no tool calls)."""
|
|
||||||
msg = MagicMock()
|
|
||||||
msg.content = response_text
|
|
||||||
msg.tool_calls = []
|
|
||||||
llm = MagicMock()
|
|
||||||
bound = MagicMock()
|
|
||||||
bound.ainvoke = AsyncMock(return_value=msg)
|
|
||||||
llm.bind_tools = MagicMock(return_value=bound)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=msg)
|
|
||||||
return llm
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_llm_with_tool_call(
|
|
||||||
tool_name: str, tool_args: dict[str, Any], final_text: str
|
|
||||||
) -> MagicMock:
|
|
||||||
"""Mock LLM that fires one tool call then returns *final_text*."""
|
|
||||||
tool_msg = MagicMock()
|
|
||||||
tool_msg.content = ""
|
|
||||||
tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
|
|
||||||
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = final_text
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
bound = MagicMock()
|
|
||||||
bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=bound)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
|
||||||
return llm
|
|
||||||
|
|
||||||
|
|
||||||
# ── Registration ──────────────────────────────────────────────────────


class TestAgentRegistration:
    def test_all_agents_registered(self) -> None:
        names = {a["name"] for a in registry.list_agents()}
        assert {
            "task_agent", "checkpoint_agent", "project_agent", "note_agent"
        }.issubset(names)

    def test_registry_returns_correct_types(self) -> None:
        assert isinstance(registry.get("task_agent"), TaskAgent)
        assert isinstance(registry.get("checkpoint_agent"), CheckpointAgent)
        assert isinstance(registry.get("project_agent"), ProjectAgent)
        assert isinstance(registry.get("note_agent"), NoteAgent)

    def test_descriptions_present(self) -> None:
        for agent_info in registry.list_agents():
            assert agent_info["description"], f"Empty description: {agent_info['name']}"


# ── TaskAgent ─────────────────────────────────────────────────────────


class TestTaskAgent:
    def test_name(self) -> None:
        assert TaskAgent().get_name() == "task_agent"

    def test_description(self) -> None:
        assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"

    def test_get_tools_count(self) -> None:
        assert len(TaskAgent().get_tools()) == 8

    def test_tool_names(self) -> None:
        names = {t.name for t in TaskAgent().get_tools()}
        assert names == {
            "list_tasks",
            "create_task",
            "update_task",
            "delete_task",
            "list_tasks_due_today",
            "list_task_comments",
            "add_task_comment",
            "delete_task_comment",
        }

    @pytest.mark.asyncio
    async def test_handle_returns_string(self) -> None:
        with patch("app.agents.task_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm("Task created.")
            result = await TaskAgent().handle("create a task", {})
        assert isinstance(result, str)

    @pytest.mark.asyncio
    async def test_handle_no_tool_calls(self) -> None:
        with patch("app.agents.task_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm("Here are your tasks.")
            result = await TaskAgent().handle("list my tasks", {})
        assert result == "Here are your tasks."

    @pytest.mark.asyncio
    async def test_handle_with_create_task_tool_call(self) -> None:
        with patch("app.agents.task_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm_with_tool_call(
                "create_task",
                {"title": "Buy groceries", "priority": "low"},
                "Task 'Buy groceries' created.",
            )
            result = await TaskAgent().handle("add a grocery task", {})
        assert result == "Task 'Buy groceries' created."

    @pytest.mark.asyncio
    async def test_handle_accepts_empty_context(self) -> None:
        with patch("app.agents.task_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm("Done.")
            result = await TaskAgent().handle("help", {})
        assert isinstance(result, str)

    @pytest.mark.asyncio
    async def test_handle_accepts_rich_context(self) -> None:
        context = {
            "user_profile": {"id": "u1", "tier": "pro"},
            "recent_tasks": [{"id": "t1", "title": "Old task"}],
        }
        with patch("app.agents.task_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm("Tasks listed.")
            result = await TaskAgent().handle("show tasks", context)
        assert isinstance(result, str)


class TestTaskAgentTools:
    @pytest.mark.asyncio
    async def test_list_tasks_defaults(self) -> None:
        from app.agents.task_agent import list_tasks
        result = await list_tasks.ainvoke({})
        data = json.loads(result)
        assert data["action"] == "list"
        assert data["table"] == "tasks"

    @pytest.mark.asyncio
    async def test_list_tasks_with_status_filter(self) -> None:
        from app.agents.task_agent import list_tasks
        result = await list_tasks.ainvoke({"status": "done"})
        data = json.loads(result)
        assert data["filters"]["status"] == "done"

    @pytest.mark.asyncio
    async def test_create_task_defaults(self) -> None:
        from app.agents.task_agent import create_task
        result = await create_task.ainvoke({"title": "Test task"})
        data = json.loads(result)
        assert data["action"] == "create_record"
        assert data["table"] == "tasks"
        assert data["data"]["title"] == "Test task"
        assert data["data"]["status"] == "todo"
        assert data["data"]["priority"] == "medium"

    @pytest.mark.asyncio
    async def test_create_task_with_all_fields(self) -> None:
        from app.agents.task_agent import create_task
        result = await create_task.ainvoke({
            "title": "Deploy",
            "priority": "high",
            "status": "in_progress",
            "project_id": "p1",
            "is_ai_suggested": 1,
        })
        data = json.loads(result)
        assert data["data"]["priority"] == "high"
        assert data["data"]["status"] == "in_progress"
        assert data["data"]["projectId"] == "p1"
        assert data["data"]["isAiSuggested"] == 1

    @pytest.mark.asyncio
    async def test_update_task_with_status(self) -> None:
        from app.agents.task_agent import update_task
        result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
        data = json.loads(result)
        assert data["action"] == "update_record"
        assert data["data"]["id"] == "t1"
        assert data["data"]["updates"]["status"] == "done"

    @pytest.mark.asyncio
    async def test_update_task_empty_updates(self) -> None:
        from app.agents.task_agent import update_task
        result = await update_task.ainvoke({"task_id": "t1"})
        data = json.loads(result)
        assert data["data"]["updates"] == {}

    @pytest.mark.asyncio
    async def test_delete_task(self) -> None:
        from app.agents.task_agent import delete_task
        result = await delete_task.ainvoke({"task_id": "t1"})
        data = json.loads(result)
        assert data["action"] == "delete_record"
        assert data["table"] == "tasks"
        assert data["data"]["id"] == "t1"

    @pytest.mark.asyncio
    async def test_list_tasks_due_today(self) -> None:
        from app.agents.task_agent import list_tasks_due_today
        result = await list_tasks_due_today.ainvoke({})
        data = json.loads(result)
        assert data["action"] == "list_due_today"
        assert data["table"] == "tasks"

    @pytest.mark.asyncio
    async def test_list_task_comments(self) -> None:
        from app.agents.task_agent import list_task_comments
        result = await list_task_comments.ainvoke({"task_id": "t1"})
        data = json.loads(result)
        assert data["action"] == "list"
        assert data["table"] == "taskComments"
        assert data["filters"]["taskId"] == "t1"

    @pytest.mark.asyncio
    async def test_add_task_comment(self) -> None:
        from app.agents.task_agent import add_task_comment
        result = await add_task_comment.ainvoke({
            "task_id": "t1",
            "author": "Alice",
            "content": "Looks good!",
        })
        data = json.loads(result)
        assert data["action"] == "create_record"
        assert data["table"] == "taskComments"
        assert data["data"]["taskId"] == "t1"
        assert data["data"]["author"] == "Alice"
        assert data["data"]["content"] == "Looks good!"

    @pytest.mark.asyncio
    async def test_delete_task_comment(self) -> None:
        from app.agents.task_agent import delete_task_comment
        result = await delete_task_comment.ainvoke({"comment_id": "c1"})
        data = json.loads(result)
        assert data["action"] == "delete_record"
        assert data["table"] == "taskComments"
        assert data["data"]["id"] == "c1"


# ── CheckpointAgent ───────────────────────────────────────────────────


class TestCheckpointAgent:
    def test_name(self) -> None:
        assert CheckpointAgent().get_name() == "checkpoint_agent"

    def test_description(self) -> None:
        assert CheckpointAgent().get_description() == "Manages project checkpoints (milestones): list, create, update, delete"

    def test_get_tools_count(self) -> None:
        assert len(CheckpointAgent().get_tools()) == 4

    def test_tool_names(self) -> None:
        names = {t.name for t in CheckpointAgent().get_tools()}
        assert names == {"list_checkpoints", "create_checkpoint", "update_checkpoint", "delete_checkpoint"}

    @pytest.mark.asyncio
    async def test_handle_no_tool_calls(self) -> None:
        with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm("No checkpoints found.")
            result = await CheckpointAgent().handle("list checkpoints", {})
        assert result == "No checkpoints found."

    @pytest.mark.asyncio
    async def test_handle_with_create_tool_call(self) -> None:
        with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm_with_tool_call(
                "create_checkpoint",
                {"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
                "Checkpoint 'MVP Launch' created.",
            )
            result = await CheckpointAgent().handle("add MVP checkpoint", {})
        assert result == "Checkpoint 'MVP Launch' created."

    @pytest.mark.asyncio
    async def test_handle_accepts_empty_context(self) -> None:
        with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm("Done.")
            result = await CheckpointAgent().handle("show milestones", {})
        assert isinstance(result, str)


class TestCheckpointAgentTools:
    @pytest.mark.asyncio
    async def test_list_checkpoints_no_project(self) -> None:
        from app.agents.checkpoint_agent import list_checkpoints
        result = await list_checkpoints.ainvoke({})
        data = json.loads(result)
        assert data["action"] == "list"
        assert data["table"] == "checkpoints"
        assert data["filters"]["projectId"] is None

    @pytest.mark.asyncio
    async def test_list_checkpoints_with_project(self) -> None:
        from app.agents.checkpoint_agent import list_checkpoints
        result = await list_checkpoints.ainvoke({"project_id": "p1"})
        data = json.loads(result)
        assert data["filters"]["projectId"] == "p1"

    @pytest.mark.asyncio
    async def test_create_checkpoint(self) -> None:
        from app.agents.checkpoint_agent import create_checkpoint
        result = await create_checkpoint.ainvoke({
            "project_id": "p1",
            "title": "Beta release",
            "date": 1700000000000,
        })
        data = json.loads(result)
        assert data["action"] == "create_record"
        assert data["table"] == "checkpoints"
        assert data["data"]["projectId"] == "p1"
        assert data["data"]["title"] == "Beta release"
        assert data["data"]["date"] == 1700000000000

    @pytest.mark.asyncio
    async def test_create_checkpoint_ai_suggested(self) -> None:
        from app.agents.checkpoint_agent import create_checkpoint
        result = await create_checkpoint.ainvoke({
            "project_id": "p1",
            "title": "Review",
            "date": 1700000000000,
            "is_ai_suggested": 1,
        })
        data = json.loads(result)
        assert data["data"]["isAiSuggested"] == 1
        assert data["data"]["isApproved"] == 0

    @pytest.mark.asyncio
    async def test_update_checkpoint_approve(self) -> None:
        from app.agents.checkpoint_agent import update_checkpoint
        result = await update_checkpoint.ainvoke({
            "checkpoint_id": "c1",
            "is_approved": 1,
        })
        data = json.loads(result)
        assert data["action"] == "update_record"
        assert data["data"]["id"] == "c1"
        assert data["data"]["updates"]["isApproved"] == 1

    @pytest.mark.asyncio
    async def test_update_checkpoint_empty_updates(self) -> None:
        from app.agents.checkpoint_agent import update_checkpoint
        result = await update_checkpoint.ainvoke({"checkpoint_id": "c1"})
        data = json.loads(result)
        assert data["data"]["updates"] == {}

    @pytest.mark.asyncio
    async def test_delete_checkpoint(self) -> None:
        from app.agents.checkpoint_agent import delete_checkpoint
        result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"})
        data = json.loads(result)
        assert data["action"] == "delete_record"
        assert data["table"] == "checkpoints"
        assert data["data"]["id"] == "c1"


# ── ProjectAgent ──────────────────────────────────────────────────────


class TestProjectAgent:
    def test_name(self) -> None:
        assert ProjectAgent().get_name() == "project_agent"

    def test_description(self) -> None:
        assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"

    def test_get_tools_count(self) -> None:
        assert len(ProjectAgent().get_tools()) == 6

    def test_tool_names(self) -> None:
        names = {t.name for t in ProjectAgent().get_tools()}
        assert names == {
            "list_projects",
            "list_all_projects",
            "get_project",
            "create_project",
            "update_project",
            "delete_project",
        }

    @pytest.mark.asyncio
    async def test_handle_no_tool_calls(self) -> None:
        with patch("app.agents.project_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm("Project Alpha is active.")
            result = await ProjectAgent().handle("show my projects", {})
        assert result == "Project Alpha is active."

    @pytest.mark.asyncio
    async def test_handle_with_create_project_tool_call(self) -> None:
        with patch("app.agents.project_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm_with_tool_call(
                "create_project",
                {"name": "Pippo"},
                "Project 'Pippo' created.",
            )
            result = await ProjectAgent().handle("create project Pippo", {})
        assert result == "Project 'Pippo' created."

    @pytest.mark.asyncio
    async def test_handle_accepts_empty_context(self) -> None:
        with patch("app.agents.project_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm("Done.")
            result = await ProjectAgent().handle("archive old project", {})
        assert isinstance(result, str)


class TestProjectAgentTools:
    @pytest.mark.asyncio
    async def test_list_projects_defaults(self) -> None:
        from app.agents.project_agent import list_projects
        result = await list_projects.ainvoke({})
        data = json.loads(result)
        assert data["action"] == "list"
        assert data["table"] == "projects"
        assert data["filters"]["includeArchived"] is False

    @pytest.mark.asyncio
    async def test_list_projects_include_archived(self) -> None:
        from app.agents.project_agent import list_projects
        result = await list_projects.ainvoke({"include_archived": 1})
        data = json.loads(result)
        assert data["filters"]["includeArchived"] is True

    @pytest.mark.asyncio
    async def test_list_all_projects(self) -> None:
        from app.agents.project_agent import list_all_projects
        result = await list_all_projects.ainvoke({})
        data = json.loads(result)
        assert data["action"] == "list_all"
        assert data["table"] == "projects"

    @pytest.mark.asyncio
    async def test_get_project(self) -> None:
        from app.agents.project_agent import get_project
        result = await get_project.ainvoke({"project_id": "p1"})
        data = json.loads(result)
        assert data["action"] == "get"
        assert data["table"] == "projects"
        assert data["data"]["id"] == "p1"

    @pytest.mark.asyncio
    async def test_create_project_name_only(self) -> None:
        from app.agents.project_agent import create_project
        result = await create_project.ainvoke({"name": "Alpha"})
        data = json.loads(result)
        assert data["action"] == "create_record"
        assert data["data"]["name"] == "Alpha"
        assert data["data"]["clientId"] is None

    @pytest.mark.asyncio
    async def test_create_project_with_client(self) -> None:
        from app.agents.project_agent import create_project
        result = await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
        data = json.loads(result)
        assert data["data"]["clientId"] == "cl1"

    @pytest.mark.asyncio
    async def test_update_project_archive(self) -> None:
        from app.agents.project_agent import update_project
        result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
        data = json.loads(result)
        assert data["action"] == "update_record"
        assert data["data"]["id"] == "p1"
        assert data["data"]["updates"]["status"] == "archived"

    @pytest.mark.asyncio
    async def test_update_project_empty_updates(self) -> None:
        from app.agents.project_agent import update_project
        result = await update_project.ainvoke({"project_id": "p1"})
        data = json.loads(result)
        assert data["data"]["updates"] == {}

    @pytest.mark.asyncio
    async def test_delete_project(self) -> None:
        from app.agents.project_agent import delete_project
        result = await delete_project.ainvoke({"project_id": "p1"})
        data = json.loads(result)
        assert data["action"] == "delete_record"
        assert data["data"]["id"] == "p1"


# ── NoteAgent ─────────────────────────────────────────────────────────


class TestNoteAgent:
    def test_name(self) -> None:
        assert NoteAgent().get_name() == "note_agent"

    def test_description(self) -> None:
        assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"

    def test_get_tools_count(self) -> None:
        assert len(NoteAgent().get_tools()) == 5

    def test_tool_names(self) -> None:
        names = {t.name for t in NoteAgent().get_tools()}
        assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}

    @pytest.mark.asyncio
    async def test_handle_no_tool_calls(self) -> None:
        with patch("app.agents.note_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm("Note created.")
            result = await NoteAgent().handle("create a note", {})
        assert result == "Note created."

    @pytest.mark.asyncio
    async def test_handle_with_create_note_tool_call(self) -> None:
        with patch("app.agents.note_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm_with_tool_call(
                "create_note",
                {"title": "Daily log", "content": "# Today\nAll good."},
                "Note 'Daily log' created.",
            )
            result = await NoteAgent().handle("log today's progress", {})
        assert result == "Note 'Daily log' created."

    @pytest.mark.asyncio
    async def test_handle_accepts_empty_context(self) -> None:
        with patch("app.agents.note_agent.get_llm") as mock_cls:
            mock_cls.return_value = _mock_llm("Done.")
            result = await NoteAgent().handle("show notes", {})
        assert isinstance(result, str)


class TestNoteAgentTools:
    @pytest.mark.asyncio
    async def test_list_notes_no_project(self) -> None:
        from app.agents.note_agent import list_notes
        result = await list_notes.ainvoke({})
        data = json.loads(result)
        assert data["action"] == "list"
        assert data["table"] == "notes"
        assert data["filters"]["projectId"] is None

    @pytest.mark.asyncio
    async def test_list_notes_with_project(self) -> None:
        from app.agents.note_agent import list_notes
        result = await list_notes.ainvoke({"project_id": "p1"})
        data = json.loads(result)
        assert data["filters"]["projectId"] == "p1"

    @pytest.mark.asyncio
    async def test_get_note(self) -> None:
        from app.agents.note_agent import get_note
        result = await get_note.ainvoke({"note_id": "n1"})
        data = json.loads(result)
        assert data["action"] == "get"
        assert data["table"] == "notes"
        assert data["data"]["id"] == "n1"

    @pytest.mark.asyncio
    async def test_create_note_minimal(self) -> None:
        from app.agents.note_agent import create_note
        result = await create_note.ainvoke({
            "title": "Daily log",
            "content": "# Today\nAll good.",
        })
        data = json.loads(result)
        assert data["action"] == "create_record"
        assert data["table"] == "notes"
        assert data["data"]["title"] == "Daily log"
        assert data["data"]["content"] == "# Today\nAll good."
        assert data["data"]["projectId"] is None

    @pytest.mark.asyncio
    async def test_create_note_with_project(self) -> None:
        from app.agents.note_agent import create_note
        result = await create_note.ainvoke({
            "title": "Sprint notes",
            "content": "## Sprint 1",
            "project_id": "p1",
        })
        data = json.loads(result)
        assert data["data"]["projectId"] == "p1"

    @pytest.mark.asyncio
    async def test_update_note_content_only(self) -> None:
        from app.agents.note_agent import update_note
        result = await update_note.ainvoke({
            "note_id": "n1",
            "content": "# Updated content",
        })
        data = json.loads(result)
        assert data["action"] == "update_record"
        assert data["data"]["id"] == "n1"
        assert data["data"]["updates"]["content"] == "# Updated content"
        assert "title" not in data["data"]["updates"]

    @pytest.mark.asyncio
    async def test_update_note_empty_updates(self) -> None:
        from app.agents.note_agent import update_note
        result = await update_note.ainvoke({"note_id": "n1"})
        data = json.loads(result)
        assert data["data"]["updates"] == {}

    @pytest.mark.asyncio
    async def test_delete_note(self) -> None:
        from app.agents.note_agent import delete_note
        result = await delete_note.ainvoke({"note_id": "n1"})
        data = json.loads(result)
        assert data["action"] == "delete_record"
        assert data["table"] == "notes"
        assert data["data"]["id"] == "n1"
@@ -1,286 +0,0 @@
"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache."""

from __future__ import annotations

import pytest

from app.core.execution_plan import (
    ExecutionPlanBuilder,
    PlanCache,
    PromptTemplateRegistry,
    plan_cache,
    template_registry,
)
from app.schemas import ExecutionPlan


# ── PromptTemplateRegistry ────────────────────────────────────────────


class TestPromptTemplateRegistry:
    def test_register_and_get(self) -> None:
        reg = PromptTemplateRegistry()
        reg.register("tpl_foo", "You are a foo agent.")
        assert reg.get("tpl_foo") == "You are a foo agent."

    def test_get_unknown_raises_key_error(self) -> None:
        reg = PromptTemplateRegistry()
        with pytest.raises(KeyError, match="tpl_missing"):
            reg.get("tpl_missing")

    def test_has_returns_true_for_registered(self) -> None:
        reg = PromptTemplateRegistry()
        reg.register("tpl_x", "prompt text")
        assert reg.has("tpl_x") is True

    def test_has_returns_false_for_unregistered(self) -> None:
        reg = PromptTemplateRegistry()
        assert reg.has("tpl_missing") is False

    def test_list_ids_returns_all_registered_ids(self) -> None:
        reg = PromptTemplateRegistry()
        reg.register("tpl_a", "a")
        reg.register("tpl_b", "b")
        assert set(reg.list_ids()) == {"tpl_a", "tpl_b"}

    def test_list_ids_does_not_return_prompt_text(self) -> None:
        reg = PromptTemplateRegistry()
        reg.register("tpl_secret", "top secret prompt")
        ids = reg.list_ids()
        assert "top secret prompt" not in ids

    def test_overwrite_existing_template(self) -> None:
        reg = PromptTemplateRegistry()
        reg.register("tpl_x", "v1")
        reg.register("tpl_x", "v2")
        assert reg.get("tpl_x") == "v2"

    def test_empty_registry_has_no_ids(self) -> None:
        reg = PromptTemplateRegistry()
        assert reg.list_ids() == []


# ── ExecutionPlanBuilder ──────────────────────────────────────────────


class TestExecutionPlanBuilder:
    def test_builds_empty_plan(self) -> None:
        plan = ExecutionPlanBuilder("task_agent").build()
        assert plan.agent == "task_agent"
        assert plan.steps == []

    def test_add_step_basic(self) -> None:
        plan = (
            ExecutionPlanBuilder("task_agent")
            .add_step("create_task", {"priority": "high"})
            .build()
        )
        assert len(plan.steps) == 1
        assert plan.steps[0].action == "create_task"
        assert plan.steps[0].variables == {"priority": "high"}
        assert plan.steps[0].prompt_template is None
        assert plan.steps[0].data_from_step is None

    def test_add_step_no_params(self) -> None:
        plan = ExecutionPlanBuilder("task_agent").add_step("fetch").build()
        assert plan.steps[0].variables is None

    def test_add_llm_step(self) -> None:
        plan = (
            ExecutionPlanBuilder("task_agent")
            .add_llm_step("tpl_task_default", {"message": "hi"})
            .build()
        )
        assert plan.steps[0].action == "llm"
        assert plan.steps[0].prompt_template == "tpl_task_default"
        assert plan.steps[0].variables == {"message": "hi"}

    def test_add_llm_step_no_variables(self) -> None:
        plan = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build()
        assert plan.steps[0].variables is None

    def test_add_data_step(self) -> None:
        plan = (
            ExecutionPlanBuilder("task_agent")
            .add_step("fetch_data")
            .add_data_step("transform", data_from_step=0)
            .build()
        )
        assert plan.steps[1].action == "transform"
        assert plan.steps[1].data_from_step == 0

    def test_fluent_chaining_returns_builder(self) -> None:
        builder = ExecutionPlanBuilder("analytics_agent")
        result = builder.add_step("a")
        assert result is builder

    def test_fluent_chain_multiple_steps(self) -> None:
        plan = (
            ExecutionPlanBuilder("analytics_agent")
            .add_llm_step("tpl_analytics_default")
            .add_step("format_output")
            .add_data_step("store", data_from_step=0)
            .build()
        )
        assert len(plan.steps) == 3

    def test_build_validates_data_from_step_out_of_range(self) -> None:
        with pytest.raises(ValueError, match="data_from_step"):
            ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=5).build()

    def test_build_validates_data_from_step_self_reference(self) -> None:
        """data_from_step=0 on the first step (index 0) is invalid."""
        with pytest.raises(ValueError, match="data_from_step"):
            ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=0).build()

    def test_build_validates_data_from_step_negative(self) -> None:
        with pytest.raises(ValueError, match="data_from_step"):
            ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=-1).build()

    def test_valid_data_from_step_at_index_two(self) -> None:
        plan = (
            ExecutionPlanBuilder("task_agent")
            .add_step("step0")
            .add_step("step1")
            .add_data_step("step2", data_from_step=1)
            .build()
        )
        assert plan.steps[2].data_from_step == 1

    def test_data_from_step_zero_valid_at_index_one(self) -> None:
        plan = (
            ExecutionPlanBuilder("task_agent")
            .add_step("step0")
            .add_data_step("step1", data_from_step=0)
            .build()
        )
        assert plan.steps[1].data_from_step == 0

    def test_build_returns_new_plan_each_call(self) -> None:
        builder = ExecutionPlanBuilder("task_agent").add_step("do_thing")
        plan1 = builder.build()
        plan2 = builder.build()
        assert plan1 is not plan2
        assert plan1.steps == plan2.steps

    def test_plan_is_execution_plan_instance(self) -> None:
        plan = ExecutionPlanBuilder("task_agent").build()
        assert isinstance(plan, ExecutionPlan)


# ── PlanCache ─────────────────────────────────────────────────────────


class TestPlanCache:
    def _plan(self, agent: str = "a") -> ExecutionPlan:
        return ExecutionPlanBuilder(agent).build()

    def test_cache_and_get(self) -> None:
        cache = PlanCache()
        plan = self._plan()
        cache.cache_plan("key1", plan)
        assert cache.get_plan("key1") is plan

    def test_get_missing_returns_none(self) -> None:
        cache = PlanCache()
        assert cache.get_plan("nonexistent") is None

    def test_get_all_playbooks_empty(self) -> None:
        cache = PlanCache()
        assert cache.get_all_playbooks() == []

    def test_get_all_playbooks_returns_all_stored(self) -> None:
        cache = PlanCache()
        p1, p2 = self._plan("a"), self._plan("b")
        cache.cache_plan("k1", p1)
        cache.cache_plan("k2", p2)
        playbooks = cache.get_all_playbooks()
        assert len(playbooks) == 2
        assert p1 in playbooks
        assert p2 in playbooks

    def test_lru_evicts_oldest_entry(self) -> None:
        cache = PlanCache(maxsize=2)
        p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
        cache.cache_plan("k1", p1)
        cache.cache_plan("k2", p2)
        cache.cache_plan("k3", p3)  # k1 should be evicted
        assert cache.get_plan("k1") is None
        assert cache.get_plan("k2") is p2
        assert cache.get_plan("k3") is p3

    def test_lru_access_updates_recency(self) -> None:
        cache = PlanCache(maxsize=2)
        p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
        cache.cache_plan("k1", p1)
        cache.cache_plan("k2", p2)
        cache.get_plan("k1")  # k1 is now most-recently used
        cache.cache_plan("k3", p3)  # k2 should be evicted (LRU)
        assert cache.get_plan("k1") is p1
        assert cache.get_plan("k2") is None
        assert cache.get_plan("k3") is p3

    def test_overwrite_existing_key(self) -> None:
        cache = PlanCache()
        p1, p2 = self._plan("a"), self._plan("b")
        cache.cache_plan("same_key", p1)
        cache.cache_plan("same_key", p2)
        assert cache.get_plan("same_key") is p2
        assert len(cache.get_all_playbooks()) == 1
|
|
||||||
|
|
||||||
def test_overwrite_does_not_consume_capacity(self) -> None:
|
|
||||||
cache = PlanCache(maxsize=2)
|
|
||||||
p1, p2 = self._plan("a"), self._plan("b")
|
|
||||||
cache.cache_plan("k1", p1)
|
|
||||||
cache.cache_plan("k1", p2) # overwrite, not a new slot
|
|
||||||
cache.cache_plan("k2", p1) # should fit without eviction
|
|
||||||
assert cache.get_plan("k1") is p2
|
|
||||||
assert cache.get_plan("k2") is p1
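The cache tests above describe least-recently-used semantics: reads refresh recency, overwrites reuse the existing slot, and the oldest entry is evicted once `maxsize` is exceeded. One way such a cache could be built is on top of `collections.OrderedDict`, as sketched below. The class name `LruPlanCacheSketch` and the default `maxsize` of 128 are assumptions; the real `PlanCache` may be implemented differently.

```python
from collections import OrderedDict
from typing import Generic, List, Optional, TypeVar

V = TypeVar("V")


class LruPlanCacheSketch(Generic[V]):
    """Hypothetical stand-in for PlanCache, shown only to illustrate LRU eviction."""

    def __init__(self, maxsize: int = 128) -> None:
        if maxsize < 1:
            raise ValueError("maxsize must be >= 1")
        self._maxsize = maxsize
        # Insertion order doubles as recency order: front = oldest, back = newest.
        self._plans: "OrderedDict[str, V]" = OrderedDict()

    def cache_plan(self, key: str, plan: V) -> None:
        if key in self._plans:
            # Overwriting reuses the existing slot and marks it most recent.
            self._plans.move_to_end(key)
        self._plans[key] = plan
        if len(self._plans) > self._maxsize:
            # Drop the least-recently used entry from the front.
            self._plans.popitem(last=False)

    def get_plan(self, key: str) -> Optional[V]:
        if key not in self._plans:
            return None
        self._plans.move_to_end(key)  # a successful read counts as a use
        return self._plans[key]

    def get_all_playbooks(self) -> List[V]:
        return list(self._plans.values())
```

Both operations stay O(1), and because an overwrite never increases the entry count it cannot trigger an eviction, which matches `test_overwrite_does_not_consume_capacity`.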


# ── Module-level singletons ───────────────────────────────────────────


class TestModuleSingletons:
    def test_template_registry_has_all_agent_defaults(self) -> None:
        for agent in ("task_agent", "checkpoint_agent", "project_agent", "note_agent"):
            assert template_registry.has(f"tpl_{agent}_default"), (
                f"Missing template: tpl_{agent}_default"
            )

    def test_template_registry_has_operation_templates(self) -> None:
        assert template_registry.has("tpl_task_extract_from_project")
        assert template_registry.has("tpl_note_weekly_summary")

    def test_template_registry_get_returns_non_empty_string(self) -> None:
        text = template_registry.get("tpl_task_agent_default")
        assert isinstance(text, str)
        assert len(text) > 0

    def test_plan_cache_has_prebuilt_playbooks(self) -> None:
        assert len(plan_cache.get_all_playbooks()) >= 2

    def test_playbook_create_tasks_from_project(self) -> None:
        plan = plan_cache.get_plan("create_tasks_from_project")
        assert plan is not None
        assert plan.agent == "project_agent"
        assert len(plan.steps) == 2
        assert plan.steps[0].prompt_template == "tpl_task_extract_from_project"
        assert plan.steps[1].data_from_step == 0

    def test_playbook_generate_weekly_note(self) -> None:
        plan = plan_cache.get_plan("generate_weekly_note")
        assert plan is not None
        assert plan.agent == "note_agent"
        assert len(plan.steps) == 2
        assert plan.steps[0].prompt_template == "tpl_note_weekly_summary"
        assert plan.steps[1].data_from_step == 0

    def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
        """Plans must not embed prompt text — only template IDs."""
        for plan in plan_cache.get_all_playbooks():
            for step in plan.steps:
                if step.prompt_template is not None:
                    assert step.prompt_template.startswith("tpl_"), (
                        f"prompt_template looks like raw text: {step.prompt_template!r}"
                    )
729
tests/test_integrations.py
Normal file
@@ -0,0 +1,729 @@
"""Tests for Step 3.6: cloud provider integration clients.

Coverage:
  Unit - app/integrations/__init__.py:
    - encrypt_token / decrypt_token round-trip
    - decrypt_token raises ValueError on invalid ciphertext
    - encrypt_token raises ValueError on empty/non-dict input
    - _get_fernet raises RuntimeError when OAUTH_ENCRYPTION_KEY not set
    - get_provider returns GmailClient for 'gmail'
    - get_provider returns MSGraphClient for 'outlook' and 'teams'
    - get_provider raises ValueError for unknown provider

  Unit - app/integrations/gmail.py:
    - _build_gmail_query with no filter returns empty string
    - _build_gmail_query with labels builds label: expr
    - _build_gmail_query with senders builds from: expr
    - _build_gmail_query with date_range builds after:/before: exprs
    - _build_gmail_query since overrides date_range.from when more recent
    - _build_gmail_query date_range.from overrides since when more recent
    - _parse_body extracts text/plain part
    - _parse_body extracts text/html part (stripped)
    - _parse_body recurses into multipart, prefers text/plain
    - GmailClient.fetch_messages: happy path with mocked service
    - GmailClient.fetch_messages: no messages returns empty list
    - GmailClient.fetch_messages: HTTP error on messages.list raises RuntimeError
    - GmailClient.refreshed_credentials: None when token unchanged
    - GmailClient.refreshed_credentials: returns dict when token changes

  Unit - app/integrations/ms_graph.py:
    - _build_email_filter with no filter returns empty string
    - _build_email_filter with senders builds OData from clause
    - _build_email_filter with since builds receivedDateTime ge clause
    - MSGraphClient.fetch_emails: happy path with mocked httpx
    - MSGraphClient.fetch_emails: 401 triggers token refresh and retries
    - MSGraphClient.fetch_messages: happy path with mocked httpx
    - MSGraphClient.fetch_messages: 403 from getAllMessages degrades gracefully
    - MSGraphClient.refreshed_credentials: None when token unchanged
    - MSGraphClient._refresh_access_token: MSAL error raises RuntimeError
"""

from __future__ import annotations

import asyncio
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch

import pytest

from app.integrations import (
    ChatMessage,
    EmailMessage,
    decrypt_token,
    encrypt_token,
    get_provider,
)

# ───────────────────────────────────────────────────────────────────────────
# Helpers
# ───────────────────────────────────────────────────────────────────────────

_FERNET_KEY = "eW91LXNob3VsZC1ub3QtdXNlLXRoaXMta2V5LWluLXByb2Q="
# ^ Placeholder only: this decodes to 35 bytes, so it is not a valid Fernet key
# (Fernet requires 32 url-safe base64-encoded bytes). A proper test key is
# generated below.

from cryptography.fernet import Fernet as _Fernet  # noqa: E402

_VALID_KEY = _Fernet.generate_key().decode("utf-8")

_TOKEN_DICT = {
    "token": "access_abc",
    "refresh_token": "refresh_xyz",
    "token_uri": "https://oauth2.googleapis.com/token",
    "client_id": "client_id_123",
    "client_secret": "client_secret_456",
    "scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
}

_MS_TOKEN_DICT = {
    "access_token": "ms_access_abc",
    "refresh_token": "ms_refresh_xyz",
    "token_type": "Bearer",
    "scope": "Mail.Read offline_access",
}


# ───────────────────────────────────────────────────────────────────────────
# encrypt_token / decrypt_token
# ───────────────────────────────────────────────────────────────────────────


class TestTokenEncryption:
    """encrypt_token / decrypt_token round-trip tests."""

    def test_round_trip(self):
        with patch("app.integrations.settings") as mock_settings:
            mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
            encrypted = encrypt_token(_TOKEN_DICT)
            assert isinstance(encrypted, str)
            assert encrypted != json.dumps(_TOKEN_DICT)  # must be ciphertext, not plaintext
            recovered = decrypt_token(encrypted)
            assert recovered == _TOKEN_DICT

    def test_decrypt_invalid_ciphertext_raises_value_error(self):
        with patch("app.integrations.settings") as mock_settings:
            mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
            with pytest.raises(ValueError, match="Failed to decrypt"):
                decrypt_token("this-is-not-valid-fernet-ciphertext")

    def test_decrypt_wrong_key_raises_value_error(self):
        """Decrypting with a different key must fail with ValueError."""
        other_key = _Fernet.generate_key().decode("utf-8")
        with patch("app.integrations.settings") as mock_settings:
            mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
            encrypted = encrypt_token(_TOKEN_DICT)
        with patch("app.integrations.settings") as mock_settings2:
            mock_settings2.OAUTH_ENCRYPTION_KEY = other_key
            with pytest.raises(ValueError, match="Failed to decrypt"):
                decrypt_token(encrypted)

    def test_encrypt_empty_dict_raises_value_error(self):
        with patch("app.integrations.settings") as mock_settings:
            mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
            with pytest.raises(ValueError, match="non-empty dict"):
                encrypt_token({})

    def test_encrypt_non_dict_raises_value_error(self):
        with patch("app.integrations.settings") as mock_settings:
            mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
            with pytest.raises(ValueError, match="non-empty dict"):
                encrypt_token("not-a-dict")  # type: ignore[arg-type]

    def test_missing_key_raises_runtime_error(self):
        with patch("app.integrations.settings") as mock_settings:
            mock_settings.OAUTH_ENCRYPTION_KEY = ""
            with pytest.raises(RuntimeError, match="OAUTH_ENCRYPTION_KEY"):
                encrypt_token(_TOKEN_DICT)

    def test_email_message_as_text(self):
        msg = EmailMessage(
            id="m1",
            subject="Hello",
            sender="alice@example.com",
            body_text="Test body",
            date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
        )
        text = msg.as_text
        assert "From: alice@example.com" in text
        assert "Subject: Hello" in text
        assert "Test body" in text

    def test_chat_message_as_text(self):
        msg = ChatMessage(
            id="c1",
            content="Buy milk",
            sender="bob",
            channel="general",
            date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
        )
        text = msg.as_text
        assert "From: bob" in text
        assert "channel: general" in text
        assert "Buy milk" in text


# ───────────────────────────────────────────────────────────────────────────
# get_provider factory
# ───────────────────────────────────────────────────────────────────────────


class TestGetProvider:
    def test_gmail_returns_gmail_client(self):
        from app.integrations.gmail import GmailClient

        client = get_provider("gmail", _TOKEN_DICT)
        assert isinstance(client, GmailClient)

    def test_outlook_returns_ms_graph_client(self):
        from app.integrations.ms_graph import MSGraphClient

        client = get_provider("outlook", _MS_TOKEN_DICT)
        assert isinstance(client, MSGraphClient)

    def test_teams_returns_ms_graph_client(self):
        from app.integrations.ms_graph import MSGraphClient

        client = get_provider("teams", _MS_TOKEN_DICT)
        assert isinstance(client, MSGraphClient)

    def test_unknown_provider_raises_value_error(self):
        with pytest.raises(ValueError, match="Unknown cloud provider"):
            get_provider("slack", {})


# ───────────────────────────────────────────────────────────────────────────
# Gmail client - query builder
# ───────────────────────────────────────────────────────────────────────────


class TestBuildGmailQuery:
    """Unit tests for gmail._build_gmail_query."""

    def setup_method(self):
        from app.integrations.gmail import _build_gmail_query
        self._fn = _build_gmail_query

    def test_empty_returns_empty_string(self):
        assert self._fn(None, None) == ""

    def test_single_label(self):
        q = self._fn({"labels": ["INBOX"]}, None)
        assert "label:INBOX" in q

    def test_multiple_labels_joined_with_or(self):
        q = self._fn({"labels": ["INBOX", "work"]}, None)
        assert "label:INBOX OR label:work" in q

    def test_senders(self):
        q = self._fn({"senders": ["alice@example.com"]}, None)
        assert "from:alice@example.com" in q

    def test_date_range_from(self):
        q = self._fn({"date_range": {"from": "2025-01-15"}}, None)
        assert "after:2025/01/15" in q

    def test_date_range_to(self):
        q = self._fn({"date_range": {"to": "2025-03-01"}}, None)
        assert "before:2025/03/01" in q

    def test_since_overrides_earlier_date_range_from(self):
        """since=Feb is more recent than date_range.from=Jan, so after: should be Feb."""
        since = datetime(2025, 2, 1, tzinfo=timezone.utc)
        q = self._fn({"date_range": {"from": "2025-01-01"}}, since)
        assert "after:2025/02/01" in q
        assert "after:2025/01/01" not in q

    def test_date_range_from_overrides_earlier_since(self):
        """date_range.from=Feb is more recent than since=Jan, so after: should be Feb."""
        since = datetime(2025, 1, 1, tzinfo=timezone.utc)
        q = self._fn({"date_range": {"from": "2025-02-01"}}, since)
        assert "after:2025/02/01" in q

    def test_invalid_date_ignored(self):
        """An invalid date string in filter_config must not raise, just be skipped."""
        q = self._fn({"date_range": {"from": "not-a-date"}}, None)
        assert "after:" not in q
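The query-builder tests above imply a small set of rules: labels and senders become `label:` and `from:` terms joined with `OR`, the later of `since` and `date_range.from` becomes the `after:` bound, and unparseable dates are skipped silently. A minimal sketch of a function with that behavior follows. It is not the project's `_build_gmail_query`; the name `build_gmail_query_sketch`, the helper `_parse_iso_date`, and the choice to parenthesize multi-value groups are assumptions made for illustration.

```python
from datetime import date, datetime, timezone
from typing import Any, List, Optional


def _parse_iso_date(value: Any) -> Optional[date]:
    """Return a date for 'YYYY-MM-DD' input, or None for missing or invalid values."""
    try:
        return date.fromisoformat(value)
    except (TypeError, ValueError):
        return None


def build_gmail_query_sketch(filter_config: Optional[dict], since: Optional[datetime]) -> str:
    """Hypothetical builder for a Gmail search string (illustration only)."""
    cfg = filter_config or {}
    terms: List[str] = []

    # Values within one group are alternatives, so they are OR-ed together.
    for operator, key in (("label", "labels"), ("from", "senders")):
        values = [f"{operator}:{value}" for value in cfg.get(key) or []]
        if len(values) == 1:
            terms.append(values[0])
        elif values:
            terms.append("(" + " OR ".join(values) + ")")

    date_range = cfg.get("date_range") or {}

    # The lower bound is whichever of date_range.from and `since` is more recent.
    lower_bounds = [_parse_iso_date(date_range.get("from"))]
    if since is not None:
        lower_bounds.append(since.date())
    lower_bounds = [bound for bound in lower_bounds if bound is not None]
    if lower_bounds:
        terms.append("after:" + max(lower_bounds).strftime("%Y/%m/%d"))

    upper = _parse_iso_date(date_range.get("to"))
    if upper is not None:
        terms.append("before:" + upper.strftime("%Y/%m/%d"))

    # Separate groups are combined with spaces, which Gmail treats as AND.
    return " ".join(terms)
```

Taking the maximum of the two lower bounds means an incremental sync cursor (`since`) can only narrow the user's configured window, never widen it.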


# ───────────────────────────────────────────────────────────────────────────
# Gmail client - body parsing
# ───────────────────────────────────────────────────────────────────────────


class TestParseBody:
    """Unit tests for gmail._parse_body."""

    def setup_method(self):
        from app.integrations.gmail import _parse_body
        self._fn = _parse_body

    def _encode(self, text: str) -> str:
        import base64
        return base64.urlsafe_b64encode(text.encode()).decode()

    def test_text_plain_extracted(self):
        payload = {
            "mimeType": "text/plain",
            "body": {"data": self._encode("Hello world")},
        }
        assert self._fn(payload) == "Hello world"

    def test_text_html_stripped(self):
        payload = {
            "mimeType": "text/html",
            "body": {"data": self._encode("<p>Hello <b>world</b></p>")},
        }
        result = self._fn(payload)
        assert "Hello" in result
        assert "<p>" not in result

    def test_multipart_prefers_plain_over_html(self):
        plain_data = self._encode("Plain text")
        html_data = self._encode("<p>HTML text</p>")
        payload = {
            "mimeType": "multipart/alternative",
            "body": {},
            "parts": [
                {"mimeType": "text/html", "body": {"data": html_data}},
                {"mimeType": "text/plain", "body": {"data": plain_data}},
            ],
        }
        result = self._fn(payload)
        assert result == "Plain text"

    def test_empty_payload_returns_empty_string(self):
        assert self._fn({}) == ""
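The body-parsing tests above describe a recursive walk over a Gmail payload: decode base64url data, prefer a `text/plain` part anywhere in the tree, fall back to `text/html` with tags removed, and return an empty string when nothing is found. The sketch below shows one possible shape for that logic. It is not the project's `_parse_body`; `parse_body_sketch` and its helpers are hypothetical names, and the regex-based tag stripping is a deliberately crude assumption.

```python
import base64
import re
from html import unescape
from typing import Optional


def _decode_part(data: str) -> str:
    """Decode base64url body data, restoring padding if it was stripped."""
    padded = data + "=" * (-len(data) % 4)
    return base64.urlsafe_b64decode(padded).decode("utf-8", errors="replace")


def _find_part(payload: dict, mime_type: str) -> Optional[str]:
    """Depth-first search for the first part of `mime_type` that carries data."""
    data = (payload.get("body") or {}).get("data")
    if payload.get("mimeType") == mime_type and data:
        return _decode_part(data)
    for part in payload.get("parts") or []:
        found = _find_part(part, mime_type)
        if found is not None:
            return found
    return None


def parse_body_sketch(payload: dict) -> str:
    """Hypothetical body extractor: text/plain first, stripped text/html second."""
    plain = _find_part(payload, "text/plain")
    if plain is not None:
        return plain
    html = _find_part(payload, "text/html")
    if html is not None:
        # Crude tag removal for illustration; a real HTML parser is more robust.
        return unescape(re.sub(r"<[^>]+>", "", html)).strip()
    return ""
```

Searching the whole tree for `text/plain` before considering HTML is what makes the result independent of part order, as `test_multipart_prefers_plain_over_html` requires.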


# ───────────────────────────────────────────────────────────────────────────
# GmailClient.fetch_messages
# ───────────────────────────────────────────────────────────────────────────


def _make_gmail_message(
    msg_id: str = "msg001",
    subject: str = "Test email",
    sender: str = "alice@example.com",
    body_text: str = "Hello world",
    date: str = "Wed, 01 Jan 2025 10:00:00 +0000",
) -> dict:
    """Build a minimal Gmail API message response dict."""
    import base64
    body_data = base64.urlsafe_b64encode(body_text.encode()).decode()
    return {
        "id": msg_id,
        "labelIds": ["INBOX"],
        "payload": {
            "mimeType": "text/plain",
            "headers": [
                {"name": "Subject", "value": subject},
                {"name": "From", "value": sender},
                {"name": "Date", "value": date},
            ],
            "body": {"data": body_data},
        },
    }


class TestGmailClientFetchMessages:
    """GmailClient.fetch_messages tests with mocked Google API."""

    def _make_client(self) -> "GmailClient":
        from app.integrations.gmail import GmailClient
        return GmailClient(_TOKEN_DICT)

    @pytest.mark.asyncio
    async def test_happy_path_returns_email_messages(self):
        client = self._make_client()
        msg = _make_gmail_message()

        mock_service = MagicMock()
        mock_users = mock_service.users.return_value
        mock_messages = mock_users.messages.return_value
        mock_messages.list.return_value.execute.return_value = {
            "messages": [{"id": "msg001"}]
        }
        mock_messages.get.return_value.execute.return_value = msg

        with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread:
            # Simulate to_thread running the sync function and returning results.
            async def fake_to_thread(fn, *args, **kwargs):
                return fn(*args, **kwargs)
            mock_thread.side_effect = fake_to_thread

            with patch("googleapiclient.discovery.build", return_value=mock_service), \
                 patch("google.auth.transport.requests.Request"), \
                 patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
                results = await client.fetch_messages()

        assert len(results) == 1
        assert results[0].subject == "Test email"
        assert results[0].sender == "alice@example.com"
        assert results[0].body_text == "Hello world"

    @pytest.mark.asyncio
    async def test_no_messages_returns_empty_list(self):
        client = self._make_client()

        mock_service = MagicMock()
        mock_users = mock_service.users.return_value
        mock_messages = mock_users.messages.return_value
        mock_messages.list.return_value.execute.return_value = {"messages": []}

        with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread:
            async def fake_to_thread(fn, *args, **kwargs):
                return fn(*args, **kwargs)
            mock_thread.side_effect = fake_to_thread

            with patch("googleapiclient.discovery.build", return_value=mock_service), \
                 patch("google.auth.transport.requests.Request"), \
                 patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
                results = await client.fetch_messages()

        assert results == []

    @pytest.mark.asyncio
    async def test_list_http_error_raises_runtime_error(self):
        import googleapiclient.errors
        client = self._make_client()

        mock_service = MagicMock()
        mock_users = mock_service.users.return_value
        mock_messages = mock_users.messages.return_value
        mock_resp = MagicMock()
        mock_resp.status = 403
        mock_resp.reason = "Forbidden"
        mock_messages.list.return_value.execute.side_effect = (
            googleapiclient.errors.HttpError(mock_resp, b"Forbidden")
        )

        with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread:
            async def fake_to_thread(fn, *args, **kwargs):
                return fn(*args, **kwargs)
            mock_thread.side_effect = fake_to_thread

            with patch("googleapiclient.discovery.build", return_value=mock_service), \
                 patch("google.auth.transport.requests.Request"), \
                 patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
                with pytest.raises(RuntimeError, match="Gmail messages.list failed"):
                    await client.fetch_messages()

    def test_refreshed_credentials_none_when_unchanged(self):
        client = self._make_client()
        # Token unchanged — should return None.
        assert client.refreshed_credentials is None

    def test_refreshed_credentials_returns_dict_when_token_changes(self):
        client = self._make_client()
        # Simulate a token refresh by changing the access token on the credentials object.
        client._credentials.token = "new_access_token_xyz"
        refreshed = client.refreshed_credentials
        assert refreshed is not None
        assert refreshed["token"] == "new_access_token_xyz"


# ───────────────────────────────────────────────────────────────────────────
# MS Graph client - email filter builder
# ───────────────────────────────────────────────────────────────────────────


class TestBuildEmailFilter:
    """Unit tests for ms_graph._build_email_filter."""

    def setup_method(self):
        from app.integrations.ms_graph import _build_email_filter
        self._fn = _build_email_filter

    def test_empty_returns_empty_string(self):
        assert self._fn(None, None) == ""

    def test_single_sender(self):
        result = self._fn({"senders": ["alice@example.com"]}, None)
        assert "from/emailAddress/address eq 'alice@example.com'" in result

    def test_multiple_senders_joined_with_or(self):
        result = self._fn({"senders": ["a@x.com", "b@x.com"]}, None)
        assert " or " in result
        assert "a@x.com" in result
        assert "b@x.com" in result

    def test_since_adds_received_date_ge_clause(self):
        since = datetime(2025, 3, 1, tzinfo=timezone.utc)
        result = self._fn(None, since)
        assert "receivedDateTime ge 2025-03-01T00:00:00Z" in result

    def test_date_range_to_adds_received_date_le_clause(self):
        result = self._fn({"date_range": {"to": "2025-06-30"}}, None)
        assert "receivedDateTime le" in result

    def test_since_overrides_earlier_date_range_from(self):
        since = datetime(2025, 2, 1, tzinfo=timezone.utc)
        result = self._fn({"date_range": {"from": "2025-01-01"}}, since)
        assert "2025-02-01T00:00:00Z" in result
        assert "2025-01-01" not in result

    def test_invalid_date_ignored(self):
        result = self._fn({"date_range": {"from": "bad-date"}}, None)
        assert "receivedDateTime" not in result
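The filter tests above suggest the OData counterpart of the Gmail builder: sender clauses joined with `or`, a `receivedDateTime ge` lower bound taken from the later of `since` and `date_range.from`, an optional `receivedDateTime le` upper bound, and invalid dates ignored. The sketch below is one way to produce such a `$filter` string. It is not the project's `_build_email_filter`; `build_email_filter_sketch` and its helpers are hypothetical, and treating naive datetimes as UTC is an assumption.

```python
from datetime import datetime, timezone
from typing import Any, List, Optional


def _as_utc(moment: datetime) -> datetime:
    """Normalize to aware UTC; naive values are assumed to already be in UTC."""
    if moment.tzinfo is None:
        return moment.replace(tzinfo=timezone.utc)
    return moment.astimezone(timezone.utc)


def _parse_utc(value: Any) -> Optional[datetime]:
    """Parse an ISO date or datetime string, returning None when it is invalid."""
    try:
        return _as_utc(datetime.fromisoformat(value))
    except (TypeError, ValueError):
        return None


def _odata_timestamp(moment: datetime) -> str:
    return moment.strftime("%Y-%m-%dT%H:%M:%SZ")


def build_email_filter_sketch(filter_config: Optional[dict], since: Optional[datetime]) -> str:
    """Hypothetical builder for a Microsoft Graph $filter expression (illustration only)."""
    cfg = filter_config or {}
    clauses: List[str] = []

    senders = cfg.get("senders") or []
    if senders:
        # OData string literals escape a single quote by doubling it.
        sender_terms = [
            "from/emailAddress/address eq '{}'".format(sender.replace("'", "''"))
            for sender in senders
        ]
        joined = " or ".join(sender_terms)
        clauses.append(f"({joined})" if len(sender_terms) > 1 else joined)

    date_range = cfg.get("date_range") or {}

    # Lower bound: the more recent of date_range.from and the sync cursor.
    lower_bounds = [_parse_utc(date_range.get("from"))]
    if since is not None:
        lower_bounds.append(_as_utc(since))
    lower_bounds = [bound for bound in lower_bounds if bound is not None]
    if lower_bounds:
        clauses.append("receivedDateTime ge " + _odata_timestamp(max(lower_bounds)))

    upper = _parse_utc(date_range.get("to"))
    if upper is not None:
        clauses.append("receivedDateTime le " + _odata_timestamp(upper))

    return " and ".join(clauses)
```

Wrapping the multi-sender group in parentheses matters because `and` binds tighter than `or` in OData, so an unparenthesized group would attach the date bounds to the last sender only.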


# ───────────────────────────────────────────────────────────────────────────
# MSGraphClient.fetch_emails
# ───────────────────────────────────────────────────────────────────────────


def _make_graph_email(
    msg_id: str = "email001",
    subject: str = "Meeting tomorrow",
    sender_address: str = "boss@company.com",
    body_content: str = "Please prepare the report.",
    received: str = "2025-06-01T10:00:00Z",
) -> dict:
    """Build a minimal MS Graph message item dict."""
    return {
        "id": msg_id,
        "subject": subject,
        "from": {"emailAddress": {"address": sender_address}},
        "receivedDateTime": received,
        "body": {"contentType": "text", "content": body_content},
        "bodyPreview": body_content[:100],
    }


def _make_graph_teams_message(
    msg_id: str = "teams001",
    content: str = "Stand-up at 9am",
    sender: str = "alice",
    channel_id: str = "chan001",
    created: str = "2025-06-01T08:00:00Z",
) -> dict:
    return {
        "id": msg_id,
        "body": {"contentType": "text", "content": content},
        "from": {"user": {"displayName": sender}},
        "channelIdentity": {"channelId": channel_id},
        "createdDateTime": created,
    }
|
||||||
|
|
||||||
|
|
||||||
|
class TestMSGraphClientFetchEmails:
|
||||||
|
"""MSGraphClient.fetch_emails tests with mocked httpx."""
|
||||||
|
|
||||||
|
def _make_client(self) -> "MSGraphClient":
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
return MSGraphClient(_MS_TOKEN_DICT)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_happy_path_returns_email_messages(self):
|
||||||
|
client = self._make_client()
|
||||||
|
graph_email = _make_graph_email()
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"value": [graph_email]}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_http = AsyncMock()
|
||||||
|
mock_http.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
results = await client.fetch_emails()
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].subject == "Meeting tomorrow"
|
||||||
|
assert results[0].sender == "boss@company.com"
|
||||||
|
assert results[0].body_text == "Please prepare the report."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pagination_stops_at_max_emails(self):
|
||||||
|
"""No nextLink in first page, so only one batch is returned."""
|
||||||
|
client = self._make_client()
|
||||||
|
emails_batch = [_make_graph_email(msg_id=str(i)) for i in range(3)]
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"value": emails_batch} # no @odata.nextLink
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_http = AsyncMock()
|
||||||
|
mock_http.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
results = await client.fetch_emails()
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_401_triggers_token_refresh_and_retries(self):
|
||||||
|
"""On first 401, token refresh is attempted and the request retried."""
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
client = MSGraphClient(_MS_TOKEN_DICT)
|
||||||
|
|
||||||
|
graph_email = _make_graph_email()
|
||||||
|
|
||||||
|
response_401 = MagicMock()
|
||||||
|
response_401.status_code = 401
|
||||||
|
|
||||||
|
response_200 = MagicMock()
|
||||||
|
response_200.status_code = 200
|
||||||
|
response_200.json.return_value = {"value": [graph_email]}
|
||||||
|
response_200.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def fake_get(url, params=None, headers=None):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return response_401
|
||||||
|
return response_200
|
||||||
|
|
||||||
|
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls, \
|
||||||
|
patch.object(client, "_refresh_access_token", new_callable=AsyncMock) as mock_refresh:
|
||||||
|
mock_http = AsyncMock()
|
||||||
|
mock_http.get = fake_get
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
results = await client.fetch_emails()
|
||||||
|
|
||||||
|
mock_refresh.assert_called_once()
|
||||||
|
assert len(results) == 1
|
||||||
|
|
||||||
|
def test_refreshed_credentials_none_when_token_unchanged(self):
|
||||||
|
client = self._make_client()
|
||||||
|
assert client.refreshed_credentials is None
|
||||||
|
|
||||||
|
def test_refreshed_credentials_returns_dict_when_token_changes(self):
|
||||||
|
client = self._make_client()
|
||||||
|
client._access_token = "new_token_abc"
|
||||||
|
assert client.refreshed_credentials is not None
|
||||||
|
assert client.refreshed_credentials["access_token"] == "new_token_abc"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMSGraphClientFetchMessages:
|
||||||
|
"""MSGraphClient.fetch_messages (Teams) tests."""
|
||||||
|
|
||||||
|
def _make_client(self) -> "MSGraphClient":
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
return MSGraphClient(_MS_TOKEN_DICT)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_happy_path_returns_chat_messages(self):
|
||||||
|
client = self._make_client()
|
||||||
|
teams_msg = _make_graph_teams_message()
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"value": [teams_msg]}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_http = AsyncMock()
|
||||||
|
mock_http.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
results = await client.fetch_messages()
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].content == "Stand-up at 9am"
|
||||||
|
assert results[0].sender == "alice"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_403_degrades_gracefully(self):
|
||||||
|
"""getAllMessages returning 403 (license issue) returns empty list, no exception."""
|
||||||
|
import httpx as _httpx
|
||||||
|
|
||||||
|
client = self._make_client()
|
||||||
|
|
||||||
|
error_response = MagicMock()
|
||||||
|
error_response.status_code = 403
|
||||||
|
http_error = _httpx.HTTPStatusError(
|
||||||
|
"Forbidden", request=MagicMock(), response=error_response
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_http = AsyncMock()
|
||||||
|
mock_http.get = AsyncMock(side_effect=http_error)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
results = await client.fetch_messages()
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_channel_filter_applied(self):
|
||||||
|
"""Messages from non-matching channels are filtered out."""
|
||||||
|
client = self._make_client()
|
||||||
|
matching = _make_graph_teams_message(channel_id="dev-channel", content="Deploy today")
|
||||||
|
non_matching = _make_graph_teams_message(msg_id="t2", channel_id="random", content="Lunch?")
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"value": [matching, non_matching]}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_http = AsyncMock()
|
||||||
|
mock_http.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
results = await client.fetch_messages(
|
||||||
|
filter_config={"channels": ["dev-channel"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].content == "Deploy today"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMSGraphClientRefreshToken:
|
||||||
|
"""MSGraphClient._refresh_access_token with mocked MSAL."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_msal_error_raises_runtime_error(self):
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
client = MSGraphClient({**_MS_TOKEN_DICT, "refresh_token": "rt_test"})
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.acquire_token_by_refresh_token.return_value = {
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": "Refresh token expired",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("msal.ConfidentialClientApplication", return_value=mock_app), \
|
||||||
|
patch("app.integrations.ms_graph.settings") as mock_settings:
|
||||||
|
mock_settings.MS_CLIENT_ID = "client_id"
|
||||||
|
mock_settings.MS_CLIENT_SECRET = "secret"
|
||||||
|
mock_settings.MS_TENANT_ID = "common"
|
||||||
|
with pytest.raises(RuntimeError, match="MS Graph token refresh failed"):
|
||||||
|
await client._refresh_access_token()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_successful_refresh_updates_access_token(self):
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
client = MSGraphClient({**_MS_TOKEN_DICT, "refresh_token": "rt_old"})
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.acquire_token_by_refresh_token.return_value = {
|
||||||
|
"access_token": "new_access_token",
|
||||||
|
"refresh_token": "new_refresh_token",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("msal.ConfidentialClientApplication", return_value=mock_app), \
|
||||||
|
patch("app.integrations.ms_graph.settings") as mock_settings:
|
||||||
|
mock_settings.MS_CLIENT_ID = "client_id"
|
||||||
|
mock_settings.MS_CLIENT_SECRET = "secret"
|
||||||
|
mock_settings.MS_TENANT_ID = "common"
|
||||||
|
await client._refresh_access_token()
|
||||||
|
|
||||||
|
assert client._access_token == "new_access_token"
|
||||||
|
assert client._refresh_token == "new_refresh_token"
|
||||||
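`test_401_triggers_token_refresh_and_retries` relies on a refresh-once-then-retry pattern. The sketch below shows that pattern in isolation, using only the standard library. It is an illustration under assumptions, not the `MSGraphClient` implementation: `FakeResponse`, `get_with_refresh`, and `run_demo` are hypothetical names.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable


@dataclass
class FakeResponse:
    """Hypothetical stand-in for an HTTP response object."""
    status_code: int


async def get_with_refresh(
    send: Callable[[], Awaitable[FakeResponse]],
    refresh: Callable[[], Awaitable[None]],
) -> FakeResponse:
    """Send a request; on a 401, refresh the token once and retry once."""
    response = await send()
    if response.status_code == 401:
        await refresh()
        response = await send()
    return response


def run_demo() -> tuple[int, int, int]:
    """Return (final status, number of sends, number of refreshes)."""
    counters = {"send": 0, "refresh": 0}

    async def send() -> FakeResponse:
        counters["send"] += 1
        # First call is rejected, later calls succeed.
        return FakeResponse(401 if counters["send"] == 1 else 200)

    async def refresh() -> None:
        counters["refresh"] += 1

    response = asyncio.run(get_with_refresh(send, refresh))
    return response.status_code, counters["send"], counters["refresh"]
```

Retrying only once avoids an endless loop when the refreshed token is also rejected; the second response is returned to the caller as-is.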
284
tests/test_memory_middleware.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
"""Tests for Step 7 — MemoryMiddleware.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
1. enrich_context returns core prefs + associative + episodic + proactive
|
||||||
|
2. store_episode creates an encrypted row decryptable with the user's key
|
||||||
|
3. update_core upserts correctly
|
||||||
|
4. User with no encryption_key returns empty context (no crash)
|
||||||
|
5. End-to-end: home_request WS frame results in an episodic row being stored
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware, _PROACTIVE_CONFIDENCE_THRESHOLD
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import (
|
||||||
|
MemoryAssociative,
|
||||||
|
MemoryCore,
|
||||||
|
MemoryEpisodic,
|
||||||
|
MemoryProactive,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
from tests.conftest import TEST_USER_IDS, make_jwt
|
||||||
|
|
||||||
|
|
||||||
|
USER_ID = TEST_USER_IDS["power"]
|
||||||
|
_FERNET_KEY = Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
|
||||||
|
# ── DB override ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def user_with_key(db_session):
|
||||||
|
"""Set encryption_key on the seeded power user."""
|
||||||
|
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = _FERNET_KEY
|
||||||
|
await db_session.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def _fernet():
|
||||||
|
return Fernet(_FERNET_KEY.encode())
|
||||||
|
|
||||||
|
|
||||||
|
def _enc(plaintext: str) -> str:
|
||||||
|
return _fernet().encrypt(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _dec(ciphertext: str) -> str:
|
||||||
|
return _fernet().decrypt(ciphertext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
# ── enrich_context ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_returns_core_memory(db_session, user_with_key):
|
||||||
|
# Seed a core memory row
|
||||||
|
db_session.add(MemoryCore(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
key="timezone",
|
||||||
|
value_encrypted=_enc("UTC"),
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "What are my tasks?")
|
||||||
|
|
||||||
|
assert "core_memory" in ctx
|
||||||
|
assert ctx["core_memory"]["timezone"] == "UTC"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_returns_episodic_memory(db_session, user_with_key):
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
db_session.add(MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=_enc("User asked about Q1 tasks"),
|
||||||
|
session_id=session_id,
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "any message")
|
||||||
|
|
||||||
|
assert "episodic_memory" in ctx
|
||||||
|
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
||||||
|
# Add one pattern above threshold and one below
|
||||||
|
db_session.add(MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
pattern_encrypted=_enc("User prefers short summaries"),
|
||||||
|
confidence=0.9,
|
||||||
|
source="inferred",
|
||||||
|
))
|
||||||
|
db_session.add(MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
pattern_encrypted=_enc("User likes dark mode"),
|
||||||
|
confidence=0.1,
|
||||||
|
source="inferred",
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "any message")
|
||||||
|
|
||||||
|
assert "proactive_hints" in ctx
|
||||||
|
hints = ctx["proactive_hints"]
|
||||||
|
assert any("short summaries" in h for h in hints)
|
||||||
|
assert not any("dark mode" in h for h in hints)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_returns_associative_memory(db_session, user_with_key):
|
||||||
|
db_session.add(MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
content_encrypted=_enc("Related memory about meetings"),
|
||||||
|
embedding=None,
|
||||||
|
entity_type="note",
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "meetings")
|
||||||
|
|
||||||
|
assert "associative_memory" in ctx
|
||||||
|
assert any("meetings" in m for m in ctx["associative_memory"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_empty_for_user_without_key(db_session):
|
||||||
|
"""User with no encryption_key → empty context, no crash."""
|
||||||
|
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = None
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "hello")
|
||||||
|
assert ctx == {}
|
||||||
|
|
||||||
|
|
||||||
|
# ── store_episode ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_store_episode_creates_encrypted_row(db_session, user_with_key):
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
await middleware.store_episode(USER_ID, session_id, "hello", "world")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
|
||||||
|
)
|
||||||
|
row = result.scalar_one()
|
||||||
|
plaintext = _dec(row.summary_encrypted)
|
||||||
|
assert "hello" in plaintext
|
||||||
|
assert "world" in plaintext
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_store_episode_decryptable(db_session, user_with_key):
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
await middleware.store_episode(USER_ID, session_id, "msg", "resp")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
|
||||||
|
)
|
||||||
|
row = result.scalar_one()
|
||||||
|
# Decrypt using the same key — must not raise
|
||||||
|
decrypted = _dec(row.summary_encrypted)
|
||||||
|
assert len(decrypted) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── update_core ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_core_insert(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
await middleware.update_core(USER_ID, "lang", "en")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == USER_ID, MemoryCore.key == "lang")
|
||||||
|
)
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert _dec(row.value_encrypted) == "en"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_core_upsert(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
await middleware.update_core(USER_ID, "lang", "en")
|
||||||
|
await middleware.update_core(USER_ID, "lang", "fr")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == USER_ID, MemoryCore.key == "lang")
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert _dec(rows[0].value_encrypted) == "fr"
|
||||||
|
|
||||||
|
|
||||||
|
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
||||||
|
|
||||||
|
def test_home_request_calls_memory_middleware(client):
|
||||||
|
"""home_request triggers enrich_context before and store_episode after the LLM."""
|
||||||
|
enrich_calls: list[tuple] = []
|
||||||
|
store_calls: list[tuple] = []
|
||||||
|
|
||||||
|
class _MockMiddleware:
|
||||||
|
def __init__(self, db):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def enrich_context(self, user_id, message):
|
||||||
|
enrich_calls.append((user_id, message))
|
||||||
|
return {"core_memory": {"tz": "UTC"}}
|
||||||
|
|
||||||
|
async def store_episode(self, user_id, session_id, message, response):
|
||||||
|
store_calls.append((user_id, session_id, message, response))
|
||||||
|
|
||||||
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
async def _mock_stream(user_id, message, context, db_session_factory=None):
|
||||||
|
# Verify memory context was injected
|
||||||
|
assert context.get("core_memory") == {"tz": "UTC"}
|
||||||
|
yield ("token", "Done")
|
||||||
|
yield ("mutations", [])
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
||||||
|
patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_stream),
|
||||||
|
):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "device_hello", "device_id": "dev-mem", "agent_ids": []
|
||||||
|
}))
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "home_request",
|
||||||
|
"request_id": "r-mem",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": "Show tasks",
|
||||||
|
}))
|
||||||
|
for _ in range(20):
|
||||||
|
raw = ws.receive_text()
|
||||||
|
frame = json.loads(raw)
|
||||||
|
if frame.get("type") == "stream_end":
|
||||||
|
break
|
||||||
|
|
||||||
|
assert len(enrich_calls) == 1
|
||||||
|
assert enrich_calls[0] == (USER_ID, "Show tasks")
|
||||||
|
assert len(store_calls) == 1
|
||||||
|
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
|
||||||
|
assert stored_session_id == session_id
|
||||||
|
assert stored_message == "Show tasks"
|
||||||
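The end-to-end test above checks an ordering: `enrich_context` runs before the stream, its result is passed in as context, and `store_episode` runs after the stream with the same session id and message. A minimal sketch of a handler with that shape follows. It is not the `device_ws` route: `RecordingMemory`, `fake_stream`, `handle_home_request`, and `run_demo` are hypothetical names, and joining the token payloads into the stored response is an assumption.

```python
import asyncio


class RecordingMemory:
    """Hypothetical stand-in for MemoryMiddleware that records its calls."""

    def __init__(self):
        self.calls = []

    async def enrich_context(self, user_id, message):
        self.calls.append(("enrich", user_id, message))
        return {"core_memory": {"tz": "UTC"}}

    async def store_episode(self, user_id, session_id, message, response):
        self.calls.append(("store", user_id, session_id, message, response))


async def fake_stream(user_id, message, context):
    """Yield (kind, payload) frames, like the mocked stream in the test."""
    yield ("token", "Do")
    yield ("token", "ne")
    yield ("mutations", [])


async def handle_home_request(memory, stream, user_id, session_id, message):
    """Enrich the context, consume the stream, then store the episode."""
    context = await memory.enrich_context(user_id, message)
    tokens = []
    async for kind, payload in stream(user_id, message, context):
        if kind == "token":
            tokens.append(payload)
    response = "".join(tokens)
    await memory.store_episode(user_id, session_id, message, response)
    return response


def run_demo():
    """Run one request and return the response plus the recorded calls."""
    memory = RecordingMemory()
    response = asyncio.run(
        handle_home_request(memory, fake_stream, "u1", "s1", "Show tasks")
    )
    return response, memory.calls
```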
205
tests/test_memory_models.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""Tests for Step 6 — memory ORM models and User.encryption_key.
|
||||||
|
|
||||||
|
Uses the SQLite in-memory test DB (from conftest). The pgvector embedding
|
||||||
|
column is stored as JSON in tests (SQLite-compatible).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.models import MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive, User
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
|
||||||
|
USER_ID = TEST_USER_IDS["power"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _fernet_key() -> str:
|
||||||
|
return Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _encrypt(key: str, plaintext: str) -> str:
|
||||||
|
return Fernet(key.encode()).encrypt(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _decrypt(key: str, ciphertext: str) -> str:
|
||||||
|
return Fernet(key.encode()).decrypt(ciphertext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
# ── User.encryption_key ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_encryption_key_column_exists(db_session):
|
||||||
|
"""User model has an encryption_key column."""
|
||||||
|
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
# Column exists (may be None for seeded users)
|
||||||
|
assert hasattr(user, "encryption_key")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_encryption_key_can_be_set(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = key
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result2 = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user2 = result2.scalar_one()
|
||||||
|
assert user2.encryption_key == key
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryCore ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_core_create_and_read(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
encrypted_val = _encrypt(key, "UTC")
|
||||||
|
|
||||||
|
row = MemoryCore(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
key="timezone",
|
||||||
|
value_encrypted=encrypted_val,
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == USER_ID)
|
||||||
|
)
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.key == "timezone"
|
||||||
|
assert _decrypt(key, fetched.value_encrypted) == "UTC"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_core_cascade_delete(db_session):
|
||||||
|
"""Deleting a user cascades to memory_core."""
|
||||||
|
row = MemoryCore(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
key="lang",
|
||||||
|
value_encrypted="enc",
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
user = (await db_session.execute(select(User).where(User.id == USER_ID))).scalar_one()
|
||||||
|
await db_session.delete(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
remaining = (
|
||||||
|
await db_session.execute(select(MemoryCore).where(MemoryCore.user_id == USER_ID))
|
||||||
|
).scalars().all()
|
||||||
|
assert remaining == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryAssociative ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_associative_create_and_read(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
content = _encrypt(key, "User prefers morning meetings")
|
||||||
|
embedding = [0.1] * 1536 # fake embedding
|
||||||
|
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
content_encrypted=content,
|
||||||
|
embedding=embedding,
|
||||||
|
entity_type="preference",
|
||||||
|
entity_id=None,
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryAssociative).where(MemoryAssociative.user_id == USER_ID)
|
||||||
|
)
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.entity_type == "preference"
|
||||||
|
assert _decrypt(key, fetched.content_encrypted) == "User prefers morning meetings"
|
||||||
|
assert len(fetched.embedding) == 1536
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryEpisodic ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_episodic_create_and_read(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
summary = _encrypt(key, "User asked about Q1 tasks")
|
||||||
|
|
||||||
|
row = MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=summary,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
|
||||||
|
)
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert _decrypt(key, fetched.summary_encrypted) == "User asked about Q1 tasks"
|
||||||
|
assert isinstance(fetched.created_at, datetime)
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryProactive ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_proactive_create_and_read(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
pattern = _encrypt(key, "User always assigns tasks to self")
|
||||||
|
|
||||||
|
row = MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
pattern_encrypted=pattern,
|
||||||
|
confidence=0.85,
|
||||||
|
source="inferred",
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryProactive).where(MemoryProactive.user_id == USER_ID)
|
||||||
|
)
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.confidence == pytest.approx(0.85)
|
||||||
|
assert fetched.source == "inferred"
|
||||||
|
assert _decrypt(key, fetched.pattern_encrypted) == "User always assigns tasks to self"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Auth registration generates encryption_key ───────────────────────────────
|
||||||
|
|
||||||
|
def test_register_sets_encryption_key(client):
|
||||||
|
"""POST /api/v1/auth/register creates a user with a valid Fernet key."""
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "newuser@test.com", "password": "testpassword123"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
|
||||||
|
# Fetch the newly created user via the access token
|
||||||
|
token = resp.json()["access_token"]
|
||||||
|
me_resp = client.get(
|
||||||
|
"/api/v1/auth/me",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert me_resp.status_code == 200
|
||||||
|
# We can't see encryption_key in the API response (not in UserProfile),
|
||||||
|
# but we verify registration didn't crash — key generation is implicit.
|
||||||
@@ -20,7 +20,6 @@ from jose import jwt
 from app.config.settings import settings
 from app.db import get_session
 from app.main import app
-from app.schemas import ChatResponse
 from tests.conftest import TEST_USER_IDS
 
 # ---------------------------------------------------------------------------
@@ -50,7 +49,6 @@ _CHAT_BODY = {
         "recent_tasks": [],
         "conversation_history": [],
     },
-    "execution_mode": "direct",
 }
 
 
@@ -240,7 +238,7 @@ class TestRateLimitMiddleware:
 
 
 class TestSanitizerMiddleware:
-    """Mock ``orchestrate`` to inject controlled strings into chat responses."""
+    """Mock ``run_home`` to inject controlled strings into chat responses."""
 
     _CHAT_PATH = "/api/v1/chat"
 
@@ -248,11 +246,10 @@ class TestSanitizerMiddleware:
         return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
 
     def _post_chat(self, client: TestClient, response_text: str) -> dict:
-        mock_response = ChatResponse(response=response_text, actions=[])
         with patch(
-            "app.api.routes.chat.orchestrate",
+            "app.api.routes.chat.run_home",
             new_callable=AsyncMock,
-            return_value=mock_response,
+            return_value=response_text,
         ):
             resp = client.post(
                 self._CHAT_PATH,
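The hunk above changes the patched target from `orchestrate` to `run_home` and makes the mock resolve to a plain string instead of a `ChatResponse`. A minimal sketch of that patching pattern follows, using only `unittest.mock`. `ChatRoutes`, `handle_chat`, and `post_chat` are hypothetical stand-ins introduced for illustration; they are not the project's route module or handler.

```python
import asyncio
from unittest.mock import AsyncMock, patch


class ChatRoutes:
    """Hypothetical stand-in for the module whose coroutine the test patches."""

    @staticmethod
    async def run_home(message: str) -> str:
        raise RuntimeError("the real agent must not run inside a unit test")


async def handle_chat(message: str) -> dict:
    # Hypothetical handler: awaits run_home and wraps the returned string.
    text = await ChatRoutes.run_home(message)
    return {"response": text}


def post_chat(response_text: str) -> dict:
    # Extra keyword arguments to patch.object are forwarded to AsyncMock,
    # so awaiting the mock yields response_text directly.
    with patch.object(
        ChatRoutes,
        "run_home",
        new_callable=AsyncMock,
        return_value=response_text,
    ) as mocked:
        result = asyncio.run(handle_chat("hello"))
        mocked.assert_awaited_once_with("hello")
    return result
```

Because the mock returns the raw string, a test can feed arbitrary text through the handler without constructing a response model first.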
@@ -1,348 +0,0 @@
-"""Integration tests for the orchestrator module."""
-
-from __future__ import annotations
-
-import json
-from typing import Any
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-from app.core.agent_registry import AgentRegistry, ChatAgent
-from app.core.orchestrator import (
-    classify_intent,
-    orchestrate,
-    orchestrate_stream,
-    route_pipeline,
-    route_single,
-)
-from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
-
-
-# ── Stub agents ──────────────────────────────────────────────────────
-
-
-class _TaskAgent(ChatAgent):
-    def get_name(self) -> str:
-        return "task_agent"
-
-    def get_description(self) -> str:
-        return "Manages tasks: create, update, list, suggest"
-
-    def get_tools(self) -> list[Any]:
-        return []
-
-    async def handle(self, query: str, context: dict[str, Any]) -> str:
-        return f"task: {query}"
-
-
-class _CalendarAgent(ChatAgent):
-    def get_name(self) -> str:
-        return "calendar_agent"
-
-    def get_description(self) -> str:
-        return "Calendar management: events, conflicts, scheduling"
-
-    def get_tools(self) -> list[Any]:
-        return []
-
-    async def handle(self, query: str, context: dict[str, Any]) -> str:
-        return f"calendar: {query}"
-
-
-# ── Helpers ──────────────────────────────────────────────────────────
-
-
-def _mock_llm(response_text: str) -> MagicMock:
-    """Return a mock LLM that always produces *response_text*."""
-    msg = MagicMock()
-    msg.content = response_text
-    llm = MagicMock()
-    llm.ainvoke = AsyncMock(return_value=msg)
-    return llm
-
-
-# ── Fixtures ─────────────────────────────────────────────────────────
-
-
-@pytest.fixture(autouse=True)
-def _fresh_registry():
-    """Reset the AgentRegistry singleton between tests."""
-    AgentRegistry._instance = None
-    yield
-    AgentRegistry._instance = None
-
-
-@pytest.fixture()
-def reg() -> AgentRegistry:
-    r = AgentRegistry()
-    r.register(_TaskAgent)
-    r.register(_CalendarAgent)
-    return r
-
-
-# ── classify_intent ───────────────────────────────────────────────────
-
-
-class TestClassifyIntent:
-    @pytest.mark.asyncio
-    async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("task_agent")
-            result = await classify_intent("add a task", {}, reg)
-        assert result == "task_agent"
-
-    @pytest.mark.asyncio
-    async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("calendar_agent")
-            result = await classify_intent("schedule a meeting", {}, reg)
-        assert result == "calendar_agent"
-
-    @pytest.mark.asyncio
-    async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("nonexistent_agent")
-            result = await classify_intent("do something", {}, reg)
-        assert result == "task_agent"
-
-    @pytest.mark.asyncio
-    async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
-        empty_reg = AgentRegistry()
-        # No LLM should be instantiated — early return path
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            result = await classify_intent("anything", {}, empty_reg)
-            mock_cls.assert_not_called()
-        assert result == "task_agent"
-
-    @pytest.mark.asyncio
-    async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm(" task_agent \n")
-            result = await classify_intent("create task", {}, reg)
-        assert result == "task_agent"
-
-
-# ── route_single ─────────────────────────────────────────────────────
-
-
-class TestRouteSingle:
-    @pytest.mark.asyncio
-    async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
-        result = await route_single("task_agent", "create a task", {}, reg)
-        assert isinstance(result, ChatResponse)
-
-    @pytest.mark.asyncio
-    async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
-        result = await route_single("task_agent", "create a task", {}, reg)
-        assert result.response == "task: create a task"
-
-    @pytest.mark.asyncio
-    async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
-        with pytest.raises(KeyError):
-            await route_single("nonexistent", "hello", {}, reg)
-
-    @pytest.mark.asyncio
-    async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
-        result = await route_single("task_agent", "hi", {}, reg)
-        assert result.actions == []
-
-
-# ── route_pipeline ────────────────────────────────────────────────────
-
-
-class TestRoutePipeline:
-    @pytest.mark.asyncio
-    async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("synthesized result")
-            result = await route_pipeline(
-                ["task_agent", "calendar_agent"], "plan my week", {}, reg
-            )
-        assert isinstance(result, ChatResponse)
-
-    @pytest.mark.asyncio
-    async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("synthesized result")
-            result = await route_pipeline(
-                ["task_agent", "calendar_agent"], "plan my week", {}, reg
-            )
-        assert result.response == "synthesized result"
-
-    @pytest.mark.asyncio
-    async def test_passes_previous_results_to_subsequent_agents(
-        self, reg: AgentRegistry
-    ) -> None:
-        """Each agent after the first should receive prior outputs in context."""
-        received_contexts: list[dict[str, Any]] = []
-
-        class _CapturingAgent(ChatAgent):
-            def get_name(self) -> str:
-                return "capture"
-
-            def get_description(self) -> str:
-                return "captures context for testing"
-
-            def get_tools(self) -> list[Any]:
-                return []
-
-            async def handle(self, query: str, context: dict[str, Any]) -> str:
-                received_contexts.append(dict(context))
-                return "captured"
-
-        reg.register(_CapturingAgent)
-
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("done")
-            await route_pipeline(["task_agent", "capture"], "hi", {}, reg)
-
-        # The second agent (capture) must have received previous results
-        assert len(received_contexts) == 1
-        assert "previous_results" in received_contexts[0]
-        assert received_contexts[0]["previous_results"] == ["task: hi"]
-
-    @pytest.mark.asyncio
-    async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("single result")
-            result = await route_pipeline(["task_agent"], "one agent", {}, reg)
-        assert result.response == "single result"
-
-
-# ── orchestrate ───────────────────────────────────────────────────────
-
-
-class TestOrchestrate:
-    @pytest.mark.asyncio
-    async def test_direct_mode_returns_chat_response(
-        self, reg: AgentRegistry
-    ) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("task_agent")
-            request = ChatRequest(message="add a task", execution_mode="direct")
-            result = await orchestrate(request, reg)
-        assert isinstance(result, ChatResponse)
-
-    @pytest.mark.asyncio
-    async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("task_agent")
-            request = ChatRequest(message="add a task", execution_mode="direct")
-            result = await orchestrate(request, reg)
-        assert isinstance(result, ChatResponse)
-        assert result.response == "task: add a task"
-
-    @pytest.mark.asyncio
-    async def test_plan_mode_returns_execution_plan(
-        self, reg: AgentRegistry
-    ) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("task_agent")
-            request = ChatRequest(message="plan my tasks", execution_mode="plan")
-            result = await orchestrate(request, reg)
-        assert isinstance(result, ExecutionPlan)
-
-    @pytest.mark.asyncio
-    async def test_plan_mode_agent_matches_classified(
-        self, reg: AgentRegistry
-    ) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("calendar_agent")
-            request = ChatRequest(
-                message="schedule something", execution_mode="plan"
-            )
-            result = await orchestrate(request, reg)
-        assert isinstance(result, ExecutionPlan)
-        assert result.agent == "calendar_agent"
-
-    @pytest.mark.asyncio
-    async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("task_agent")
-            request = ChatRequest(message="plan tasks", execution_mode="plan")
-            result = await orchestrate(request, reg)
-        assert isinstance(result, ExecutionPlan)
-        assert len(result.steps) >= 1
-
-    @pytest.mark.asyncio
-    async def test_plan_mode_template_id_contains_agent_name(
-        self, reg: AgentRegistry
-    ) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("task_agent")
-            request = ChatRequest(message="plan tasks", execution_mode="plan")
-            result = await orchestrate(request, reg)
-        assert isinstance(result, ExecutionPlan)
-        assert result.steps[0].prompt_template is not None
-        assert "task_agent" in result.steps[0].prompt_template
-
-    @pytest.mark.asyncio
-    async def test_default_execution_mode_is_direct(
-        self, reg: AgentRegistry
-    ) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("task_agent")
-            # execution_mode defaults to "direct"
-            request = ChatRequest(message="help me")
-            result = await orchestrate(request, reg)
-        assert isinstance(result, ChatResponse)
-
-
-# ── orchestrate_stream ────────────────────────────────────────────────
-
-
-class TestOrchestrateStream:
-    @pytest.mark.asyncio
-    async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("task_agent")
-            request = ChatRequest(message="add a task", execution_mode="direct")
-            chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
-        assert len(chunks) >= 1
-
-    @pytest.mark.asyncio
-    async def test_last_chunk_is_final_json_frame(
-        self, reg: AgentRegistry
-    ) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("task_agent")
-            request = ChatRequest(message="add a task", execution_mode="direct")
-            chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
-
-        last = json.loads(chunks[-1])
-        assert last["done"] is True
-        assert "response" in last
-        assert "actions" in last
-
-    @pytest.mark.asyncio
-    async def test_final_frame_response_matches_agent_output(
-        self, reg: AgentRegistry
-    ) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("task_agent")
-            request = ChatRequest(message="create a task", execution_mode="direct")
-            chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
-
-        final = json.loads(chunks[-1])
-        assert final["response"] == "task: create a task"
-
-    @pytest.mark.asyncio
-    async def test_text_chunks_before_final_frame(
-        self, reg: AgentRegistry
-    ) -> None:
-        with patch("app.core.orchestrator._make_llm") as mock_cls:
-            mock_cls.return_value = _mock_llm("task_agent")
-            request = ChatRequest(
-                message="x" * 200, execution_mode="direct"
-            )  # long enough to produce multiple chunks
-            chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
-
-        # All but the last chunk should be plain text (not valid final JSON)
-        non_final = chunks[:-1]
-        for chunk in non_final:
-            try:
-                parsed = json.loads(chunk)
-                assert parsed.get("done") is not True
-            except json.JSONDecodeError:
-                pass  # plain text chunk — expected
214
tests/test_output_formatter.py
Normal file
@@ -0,0 +1,214 @@
+"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
+
+from __future__ import annotations
+
+import pytest
+
+from app.core.output_formatter import HomeFormatter, FloatingFormatter
+from app.schemas import (
+    WsFloatingDomain,
+    WsStreamEnd,
+    WsStreamStart,
+    WsStreamText,
+)
+
+
+# ── helpers ───────────────────────────────────────────────────────────────────
+
+async def _stream(*events: tuple[str, object]):
+    """Async generator that yields (event_type, data) tuples."""
+    for event in events:
+        yield event
+
+
+async def collect(formatter, event_stream):
+    frames = []
+    async for frame in formatter.format(event_stream):
+        frames.append(frame)
+    return frames
+
+
+# ── HomeFormatter ─────────────────────────────────────────────────────────────
+
+@pytest.mark.asyncio
+async def test_home_formatter_plain_text():
+    req_id = "req-1"
+    events = [
+        ("token", "Hello world"),
+        ("mutations", []),
+    ]
+    formatter = HomeFormatter(request_id=req_id)
+    frames = await collect(formatter, _stream(*events))
+
+    assert isinstance(frames[0], WsStreamStart)
+    assert frames[0].request_id == req_id
+    text_frames = [f for f in frames if isinstance(f, WsStreamText)]
+    assert any("Hello world" in f.chunk for f in text_frames)
+    assert isinstance(frames[-1], WsStreamEnd)
+
+
+@pytest.mark.asyncio
+async def test_home_formatter_entity_tags_passed_through():
+    """Entity tags are streamed as-is — the frontend parses them."""
+    req_id = "req-2"
+    events = [
+        ("token", "Here is your project:\n<project>[abc-123]</project>\nAll good."),
+        ("mutations", []),
+    ]
+    formatter = HomeFormatter(request_id=req_id)
+    frames = await collect(formatter, _stream(*events))
+
+    text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
+    assert "<project>[abc-123]</project>" in text
+    assert "Here is your project:" in text
+    assert "All good." in text
+
+
+@pytest.mark.asyncio
+async def test_home_formatter_multiple_tags_passed_through():
+    req_id = "req-3"
+    events = [
+        ("token", "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"),
+        ("mutations", []),
+    ]
+    formatter = HomeFormatter(request_id=req_id)
+    frames = await collect(formatter, _stream(*events))
+
+    text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
+    assert "<project>[p1]</project>" in text
+    assert "<task>[t1,t2]</task>" in text
+
+
+@pytest.mark.asyncio
+async def test_home_formatter_tool_end_ignored():
+    """tool_end events are silently ignored by HomeFormatter."""
+    req_id = "req-4"
+    events = [
+        ("tool_end", {"name": "task_agent", "result": "3 tasks"}),
+        ("token", "No tags here."),
+        ("mutations", []),
+    ]
+    formatter = HomeFormatter(request_id=req_id)
+    frames = await collect(formatter, _stream(*events))
+
+    text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
+    assert text == "No tags here."
+
+
+@pytest.mark.asyncio
+async def test_home_formatter_mutations_in_stream_end():
+    req_id = "req-5"
+    muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
+    events = [
+        ("token", "Done"),
+        ("mutations", muts),
+    ]
+    formatter = HomeFormatter(request_id=req_id)
+    frames = await collect(formatter, _stream(*events))
+
+    end_frame = frames[-1]
+    assert isinstance(end_frame, WsStreamEnd)
+    assert len(end_frame.mutations) == 1
+    assert end_frame.mutations[0]["action"] == "insert"
+
+
+@pytest.mark.asyncio
+async def test_home_formatter_frame_order():
+    """stream_start is first, stream_end is last."""
+    req_id = "req-6"
+    formatter = HomeFormatter(request_id=req_id)
+    frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
+    assert isinstance(frames[0], WsStreamStart)
+    assert isinstance(frames[-1], WsStreamEnd)
+
+
+# ── FloatingFormatter ─────────────────────────────────────────────────────────
+
+@pytest.mark.asyncio
+async def test_floating_formatter_domain_from_tool_end():
+    req_id = "pop-1"
+    formatter = FloatingFormatter(request_id=req_id)
+    events = [
+        ("tool_end", {"name": "task_agent", "result": "ok"}),
+        ("token", "Hello"),
+        ("mutations", []),
+    ]
+    frames = await collect(formatter, _stream(*events))
+
+    assert isinstance(frames[0], WsFloatingDomain)
+    assert frames[0].domain == "tasks"
+    assert frames[0].request_id == req_id
+
+
+@pytest.mark.asyncio
+async def test_floating_formatter_text_only():
+    req_id = "pop-2"
+    formatter = FloatingFormatter(request_id=req_id)
+    events = [
+        ("tool_end", {"name": "timeline_agent", "result": "done"}),
+        ("token", "Summary"),
+        ("mutations", []),
+    ]
+    frames = await collect(formatter, _stream(*events))
+
+    assert isinstance(frames[0], WsFloatingDomain)
+    assert frames[0].domain == "timelines"
+    text_frames = [f for f in frames if isinstance(f, WsStreamText)]
+    assert len(text_frames) == 1
+    assert text_frames[0].chunk == "Summary"
+
+
+@pytest.mark.asyncio
+async def test_floating_formatter_no_entity_tags():
+    """FloatingFormatter never emits entity tag blocks."""
+    req_id = "pop-3"
+    formatter = FloatingFormatter(request_id=req_id)
+    events = [
+        ("tool_end", {"name": "note_agent", "result": "data"}),
+        ("token", "some text"),
+        ("mutations", []),
+    ]
+    frames = await collect(formatter, _stream(*events))
+    # Only expected frame types
+    for f in frames:
+        assert isinstance(f, (WsFloatingDomain, WsStreamStart, WsStreamText, WsStreamEnd))
+
+
+@pytest.mark.asyncio
+async def test_floating_formatter_end_frame():
+    req_id = "pop-4"
+    formatter = FloatingFormatter(request_id=req_id)
+    events = [
+        ("tool_end", {"name": "project_agent", "result": "ok"}),
+        ("token", "Done"),
+        ("mutations", []),
+    ]
+    frames = await collect(formatter, _stream(*events))
+    assert isinstance(frames[-1], WsStreamEnd)
+
+
+@pytest.mark.asyncio
+async def test_floating_formatter_default_domain_on_early_token():
+    """When the first event is a token (no tool_end yet), default to 'tasks'."""
+    req_id = "pop-5"
+    formatter = FloatingFormatter(request_id=req_id)
+    events = [("token", "hi"), ("mutations", [])]
+    frames = await collect(formatter, _stream(*events))
+    assert isinstance(frames[0], WsFloatingDomain)
+    assert frames[0].domain == "tasks"
+
+
+@pytest.mark.asyncio
+async def test_floating_formatter_mutations_in_stream_end():
+    req_id = "pop-6"
+    muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
+    events = [
+        ("token", "Updated"),
+        ("mutations", muts),
+    ]
+    formatter = FloatingFormatter(request_id=req_id)
+    frames = await collect(formatter, _stream(*events))
+
+    end_frame = frames[-1]
+    assert isinstance(end_frame, WsStreamEnd)
+    assert len(end_frame.mutations) == 1
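The HomeFormatter tests above pin down a small contract: a start frame first, one text frame per `token` event, `tool_end` events ignored, and an end frame last that carries the mutations. A minimal sketch of a formatter that would satisfy that contract follows, assuming plain dataclasses in place of the project's frame models. `SketchStart`, `SketchText`, `SketchEnd`, and `SketchHomeFormatter` are hypothetical names; the actual `app.core.output_formatter` implementation is not shown in this diff and may differ.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any, AsyncIterator


@dataclass
class SketchStart:
    request_id: str


@dataclass
class SketchText:
    request_id: str
    chunk: str


@dataclass
class SketchEnd:
    request_id: str
    mutations: list = field(default_factory=list)


class SketchHomeFormatter:
    """Hypothetical formatter: turns (event_type, data) tuples into frames."""

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id

    async def format(self, events: AsyncIterator[tuple]):
        yield SketchStart(self.request_id)
        mutations: list = []
        async for kind, data in events:
            if kind == "token":
                # Text is passed through untouched, entity tags included.
                yield SketchText(self.request_id, data)
            elif kind == "mutations":
                mutations = list(data)
            # Any other event type (for example "tool_end") is ignored.
        yield SketchEnd(self.request_id, mutations)


async def _events(*events: tuple):
    for event in events:
        yield event


async def demo() -> list[Any]:
    formatter = SketchHomeFormatter("req-1")
    stream = _events(
        ("tool_end", {"name": "task_agent", "result": "ok"}),
        ("token", "Hi"),
        ("mutations", [{"action": "insert"}]),
    )
    return [frame async for frame in formatter.format(stream)]


frames = asyncio.run(demo())
```

Collecting mutations and emitting them only in the final frame keeps the text stream free of control data, which matches what the `*_mutations_in_stream_end` tests check.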
@@ -88,7 +88,7 @@ class TestPluginRegistry:
     async def test_list_filter_by_query(
         self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
     ) -> None:
-        result = await reg.list_plugins(db_session, query="time")
+        result = await reg.list_plugins(db_session, query="time tracker")
         assert result.total == 1
         assert result.plugins[0].id == "plugin-time-tracker"
 
230
tests/test_schemas_v3.py
Normal file
@@ -0,0 +1,230 @@
+"""Tests for v3 WebSocket frame protocol schemas."""
+
+import pytest
+from pydantic import ValidationError
+
+from app.schemas import (
+    WsFrameType,
+    WsHomeRequest,
+    WsFloatingDomain,
+    WsFloatingRequest,
+    WsFloatingScope,
+    WsStreamEnd,
+    WsStreamStart,
+    WsStreamText,
+)
+
+
+# ── WsFrameType ───────────────────────────────────────────────────────
+
+
+def test_v3_frame_types_exist():
+    v3_types = [
+        "home_request",
+        "floating_request",
+        "stream_start",
+        "stream_text",
+        "stream_end",
+        "floating_domain",
+        "data_request",
+        "data_response",
+        "mutation",
+    ]
+    for name in v3_types:
+        assert hasattr(WsFrameType, name), f"WsFrameType missing: {name}"
+        assert WsFrameType[name].value == name
+
+
+def test_v2_frame_types_still_exist():
+    """Backward compat: v2 types must remain."""
+    v2_types = [
+        "chat_request",
+        "text_chunk",
+        "tool_call",
+        "tool_result",
+        "final",
+        "ping",
+        "agent_run",
+        "agent_data",
+        "agent_complete",
+        "device_hello",
+    ]
+    for name in v2_types:
+        assert hasattr(WsFrameType, name), f"v2 WsFrameType missing: {name}"
+
+
+# ── WsHomeRequest ─────────────────────────────────────────────────────
+
+
+def test_home_request_defaults():
+    frame = WsHomeRequest(message="Hello")
+    assert frame.type == WsFrameType.home_request
+    assert frame.message == "Hello"
+    assert frame.conversation_history == []
+
+
+def test_home_request_with_history():
+    history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
+    frame = WsHomeRequest(message="Follow up", conversation_history=history)
+    assert frame.conversation_history == history
+
+
+def test_home_request_serializes():
+    frame = WsHomeRequest(message="Test")
+    data = frame.model_dump()
+    assert data["type"] == "home_request"
+    assert data["message"] == "Test"
+    assert data["conversation_history"] == []
+
+
+def test_home_request_deserializes():
+    raw = {"type": "home_request", "message": "Hi there"}
+    frame = WsHomeRequest.model_validate(raw)
+    assert frame.message == "Hi there"
+
+
+def test_home_request_requires_message():
+    with pytest.raises(ValidationError):
+        WsHomeRequest.model_validate({"type": "home_request"})
+
+
+# ── WsFloatingRequest ────────────────────────────────────────────────────
+
+
+def test_floating_request_basic():
+    frame = WsFloatingRequest(
+        message="Summarise",
+        scope=WsFloatingScope(type="task", id="task-123"),
+    )
+    assert frame.type == WsFrameType.floating_request
+    assert frame.scope.type == "task"
+    assert frame.scope.id == "task-123"
+
+
+def test_floating_request_scope_without_id():
+    frame = WsFloatingRequest(
+        message="Show all",
+        scope=WsFloatingScope(type="project"),
+    )
+    assert frame.scope.id is None
+
+
+def test_floating_request_serializes():
+    frame = WsFloatingRequest(
+        message="Test",
+        scope=WsFloatingScope(type="note", id="n-1"),
+    )
+    data = frame.model_dump()
+    assert data["type"] == "floating_request"
+    assert data["scope"]["type"] == "note"
+    assert data["scope"]["id"] == "n-1"
+
+
+def test_floating_request_invalid_scope_type():
+    with pytest.raises(ValidationError):
+        WsFloatingRequest(
+            message="X",
+            scope=WsFloatingScope(type="unknown"),  # type: ignore[arg-type]
+        )
+
+
+def test_floating_request_requires_scope():
+    with pytest.raises(ValidationError):
+        WsFloatingRequest.model_validate({"type": "floating_request", "message": "X"})
+
+
+# ── WsStreamStart ─────────────────────────────────────────────────────
+
+
+def test_stream_start():
+    frame = WsStreamStart(request_id="req-abc")
+    assert frame.type == WsFrameType.stream_start
+    assert frame.request_id == "req-abc"
+
+
+def test_stream_start_serializes():
+    data = WsStreamStart(request_id="r1").model_dump()
+    assert data == {"type": "stream_start", "request_id": "r1"}
+
+
+def test_stream_start_deserializes():
+    frame = WsStreamStart.model_validate({"type": "stream_start", "request_id": "r1"})
+    assert frame.request_id == "r1"
+
+
+# ── WsStreamText ──────────────────────────────────────────────────────
+
+
+def test_stream_text():
+    frame = WsStreamText(request_id="r1", chunk="Hello ")
+    assert frame.type == WsFrameType.stream_text
+    assert frame.chunk == "Hello "
+
+
+def test_stream_text_serializes():
+    data = WsStreamText(request_id="r1", chunk="word").model_dump()
+    assert data == {"type": "stream_text", "request_id": "r1", "chunk": "word"}
+
+
+def test_stream_text_deserializes():
+    raw = {"type": "stream_text", "request_id": "r2", "chunk": "test"}
+    frame = WsStreamText.model_validate(raw)
+    assert frame.chunk == "test"
+
+
+# ── WsStreamEnd ───────────────────────────────────────────────────────
+
+
+def test_stream_end_defaults():
+    frame = WsStreamEnd(request_id="r1")
+    assert frame.type == WsFrameType.stream_end
+    assert frame.mutations == []
+
+
+def test_stream_end_with_mutations():
+    mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}]
+    frame = WsStreamEnd(request_id="r1", mutations=mutations)
+    assert len(frame.mutations) == 1
+    assert frame.mutations[0]["action"] == "create"
+
+
+def test_stream_end_serializes():
+    data = WsStreamEnd(request_id="r2").model_dump()
+    assert data == {"type": "stream_end", "request_id": "r2", "mutations": []}
+
+
+def test_stream_end_deserializes():
+    raw = {"type": "stream_end", "request_id": "r3", "mutations": []}
+    frame = WsStreamEnd.model_validate(raw)
+    assert frame.request_id == "r3"
+
+
|
# ── WsFloatingDomain ─────────────────────────────────────────────────────


def test_floating_domain_tasks():
    frame = WsFloatingDomain(request_id="r1", domain="tasks")
    assert frame.type == WsFrameType.floating_domain
    assert frame.domain == "tasks"


@pytest.mark.parametrize("domain", ["tasks", "timelines", "notes", "projects"])
def test_floating_domain_valid_domains(domain: str):
    frame = WsFloatingDomain(request_id="r1", domain=domain)  # type: ignore[arg-type]
    assert frame.domain == domain


def test_floating_domain_invalid():
    with pytest.raises(ValidationError):
        WsFloatingDomain(request_id="r1", domain="invalid")  # type: ignore[arg-type]


def test_floating_domain_serializes():
    d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
    assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}


def test_floating_domain_deserializes():
    raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
    frame = WsFloatingDomain.model_validate(raw)
    assert frame.domain == "projects"
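The `WsStreamEnd` tests above expect `mutations` to default to an empty list and the serialized frame to carry exactly three keys. As a rough illustration of that shape only, here is a standard-library sketch; `StreamEndSketch` is a hypothetical stand-in built on `dataclasses`, not the real model from `app.schemas`, which the tests suggest is a Pydantic model.

```python
from dataclasses import asdict, dataclass, field
from typing import Any


@dataclass
class StreamEndSketch:
    """Hypothetical stand-in for WsStreamEnd; not the project's actual model."""

    request_id: str
    # default_factory gives every instance its own list, so a mutation
    # appended to one frame never appears on another.
    mutations: list[dict[str, Any]] = field(default_factory=list)
    type: str = "stream_end"


first = StreamEndSketch(request_id="r2")
second = StreamEndSketch(request_id="r3")
first.mutations.append({"action": "create", "table": "tasks"})

# Serialized shape of a frame that was created without mutations.
payload = asdict(second)
```

The per-instance default is the property that `test_stream_end_defaults` relies on: a frame built without `mutations` always starts from an empty list.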
158
tests/test_ws_unified.py
Normal file
@@ -0,0 +1,158 @@
"""Integration tests for the unified WebSocket handler (Step 5).

Tests the device WS endpoint with home_request and floating_request frames,
verifying that the correct v3 frame sequence is returned.

LLM calls are mocked to avoid network dependency.
"""

from __future__ import annotations

import json
from unittest.mock import patch

import pytest

from app.db import get_session
from app.main import app
from app.schemas import WsFrameType
from tests.conftest import TEST_USER_IDS, make_jwt

USER_ID = TEST_USER_IDS["power"]


# ── helpers ───────────────────────────────────────────────────────────────────

@pytest.fixture(autouse=True)
def _override_db(db_session):
    async def _gen():
        yield db_session

    app.dependency_overrides[get_session] = _gen
    yield
    app.dependency_overrides.pop(get_session, None)


def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
    """Receive frames until a stream_end frame arrives or max_frames have been read."""
    frames = []
    for _ in range(max_frames):
        raw = ws.receive_text()
        frame = json.loads(raw)
        frames.append(frame)
        if frame.get("type") == WsFrameType.stream_end:
            break
    return frames


async def _mock_home_stream(user_id, message, context, db_session_factory=None):
    yield "token", "Here are your tasks:\n<task>[t1,t2]</task>"
    yield "mutations", []


async def _mock_floating_stream(user_id, message, context, scope=None, db_session_factory=None):
    yield "tool_end", {"name": "task_agent", "result": "ok"}
    yield "token", "Here is a summary"
    yield "mutations", []


# ── tests ─────────────────────────────────────────────────────────────────────

def test_home_request_produces_stream_frames(client):
    """home_request → stream_start, stream_text+, stream_end."""
    token = make_jwt("power", user_id=USER_ID)

    with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-1", "agent_ids": []
            }))
            ws.send_text(json.dumps({
                "type": "home_request",
                "request_id": "r1",
                "message": "List my tasks",
                "conversation_history": [],
            }))
            frames = _recv_until_end(ws)

    types = [f["type"] for f in frames]
    assert WsFrameType.stream_start in types
    assert WsFrameType.stream_end in types
    assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)


def test_floating_request_produces_domain_frame(client):
    """floating_request → floating_domain first, then stream_text*, stream_end."""
    token = make_jwt("power", user_id=USER_ID)

    with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-2", "agent_ids": []
            }))
            ws.send_text(json.dumps({
                "type": "floating_request",
                "request_id": "p1",
                "message": "Summarize this task",
                "scope": {"type": "task", "id": "task-123"},
            }))
            frames = _recv_until_end(ws)

    types = [f["type"] for f in frames]
    assert WsFrameType.floating_domain in types
    assert WsFrameType.stream_end in types
    assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)

    domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
    assert domain_frame["domain"] == "tasks"
    assert domain_frame["request_id"] == "p1"


def test_home_request_request_id_propagated(client):
    """request_id in home_request is echoed in all response frames."""
    token = make_jwt("power", user_id=USER_ID)
    req_id = "my-unique-req-id"

    async def _stream(user_id, message, context, db_session_factory=None):
        yield "token", "ok"
        yield "mutations", []

    with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-3", "agent_ids": []
            }))
            ws.send_text(json.dumps({
                "type": "home_request",
                "request_id": req_id,
                "message": "hello",
            }))
            frames = _recv_until_end(ws)

    for f in frames:
        if "request_id" in f:
            assert f["request_id"] == req_id


def test_tool_result_dispatch_silent_on_unknown_id(client):
    """tool_result for unknown call_id is silently ignored — no crash."""
    token = make_jwt("power", user_id=USER_ID)

    with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-4", "agent_ids": []
            }))
            ws.send_text(json.dumps({
                "type": "tool_result", "id": "no-such-id", "ok": True
            }))
            # If connection is still alive, we'll get the heartbeat ping
            msg = json.loads(ws.receive_text())
            assert msg["type"] == "ping"


def test_invalid_jwt_rejected(client):
    """Connection with bad token is closed before or after accept."""
    with pytest.raises(Exception):
        with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws:
            ws.receive_text()
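The integration tests above compare JSON-decoded strings directly against `WsFrameType` members, both with `in` and with `list.index`. That only works if the enum members are also strings. The sketch below shows the idea with the standard library alone; `FrameTypeSketch` and `ordering_ok` are hypothetical names introduced for illustration, and the assumption that `WsFrameType` mixes in `str` is inferred from the tests, not confirmed by the diff.

```python
import json
from enum import Enum


class FrameTypeSketch(str, Enum):
    """Hypothetical stand-in for app.schemas.WsFrameType (assumed str-based)."""

    stream_start = "stream_start"
    stream_text = "stream_text"
    stream_end = "stream_end"


def ordering_ok(raw_frames: list[str], request_id: str) -> bool:
    """Check that stream_start precedes stream_end and request_id is echoed."""
    frames = [json.loads(raw) for raw in raw_frames]
    types = [f["type"] for f in frames]
    # Membership and index() use ==, which succeeds because members are str.
    if FrameTypeSketch.stream_start not in types:
        return False
    if FrameTypeSketch.stream_end not in types:
        return False
    in_order = types.index(FrameTypeSketch.stream_start) < types.index(
        FrameTypeSketch.stream_end
    )
    # Frames without a request_id (a heartbeat ping, say) are skipped.
    echoed = all(f["request_id"] == request_id for f in frames if "request_id" in f)
    return in_order and echoed


good = [
    json.dumps({"type": "stream_start", "request_id": "r1"}),
    json.dumps({"type": "stream_text", "request_id": "r1", "chunk": "ok"}),
    json.dumps({"type": "stream_end", "request_id": "r1", "mutations": []}),
]
```

A helper along these lines could fold the ordering and echo assertions repeated across the tests into one call, at the cost of less specific failure messages.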