Compare commits
22 Commits
feature/de
...
6c450805cb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c450805cb | ||
|
|
f340d0fa3e | ||
|
|
edc53cb6eb | ||
|
|
725cece5c1 | ||
|
|
297e20ce8d | ||
|
|
5a03bd1cfb | ||
|
|
87b7a1c6c9 | ||
|
|
826f64d6bb | ||
| 5faa6b1d7c | |||
| 02a9684cd6 | |||
| fae9efee0d | |||
| 30b062dd4a | |||
| 2a0331d7ce | |||
| 13fd8677c1 | |||
| 9bd629cb59 | |||
| 9c97702daa | |||
| a1e364c9c0 | |||
| 5b55f1292a | |||
| 5bc9ea6cd6 | |||
| f7404b6f66 | |||
| d667e43c73 | |||
| fe085a7951 |
@@ -1,523 +0,0 @@
|
|||||||
# AI Refactor Plan — Adiuva Backend
|
|
||||||
|
|
||||||
> **Objective:** Transform backend tools from JSON-action-descriptor-returning functions into real bidirectional executors. Each tool sends structured CRUD operations to the Electron client via WebSocket, receives real data back, and returns meaningful results to the LLM. The LLM reasons about actual user data instead of serialized action payloads.
|
|
||||||
>
|
|
||||||
> **Electron app:** Lives at `../adiuva/`. See `../adiuva/AI_REFACTOR_PLAN.md`.
|
|
||||||
>
|
|
||||||
> **Protocol:** Execute steps sequentially. Each step is atomic and committable. Mark `[x]` when done.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Architecture — Before vs After
|
|
||||||
|
|
||||||
### Before (current)
|
|
||||||
```
|
|
||||||
LLM calls list_tasks(status="todo")
|
|
||||||
→ tool returns: '{"action":"list","table":"tasks","filters":{"status":"todo"}}'
|
|
||||||
→ _tool_loop feeds that JSON string as ToolMessage to LLM
|
|
||||||
→ LLM sees a descriptor, NOT real data — cannot reason about tasks
|
|
||||||
→ Final response: generic "Here are your tasks" (no actual task data)
|
|
||||||
→ Action descriptors sent in final WS frame for Electron to execute post-response
|
|
||||||
```
|
|
||||||
|
|
||||||
### After (target)
|
|
||||||
```
|
|
||||||
LLM calls list_tasks(status="todo")
|
|
||||||
→ tool calls execute_on_client(action="select", table="tasks", filters={status:"todo"})
|
|
||||||
→ WS frame sent to Electron: {type:"tool_call", id:"abc", action:"select", table:"tasks", filters:{status:"todo"}}
|
|
||||||
→ Electron runs: db.select().from(tasks).where(eq(tasks.status, "todo")).all()
|
|
||||||
→ WS frame back: {type:"tool_result", id:"abc", rows:[{id:"1",title:"Buy milk",...}, ...]}
|
|
||||||
→ tool returns: "Found 3 tasks: 1. Buy milk (high, due tomorrow) 2. ..."
|
|
||||||
→ _tool_loop feeds that as ToolMessage to LLM
|
|
||||||
→ LLM sees REAL data — can reason, count, compare, summarize
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## WS Protocol — Typed Frames
|
|
||||||
|
|
||||||
| Direction | `type` | Payload |
|
|
||||||
|---|---|---|
|
|
||||||
| Client → Server | `chat_request` | `{ message: str, context: ChatContext }` |
|
|
||||||
| Server → Client | `text_chunk` | `{ text: str }` |
|
|
||||||
| Server → Client | `tool_call` | `{ id: str, action: str, table?: str, data?: dict, filters?: dict, vector?: list[float], limit?: int }` |
|
|
||||||
| Client → Server | `tool_result` | `{ id: str, row?: dict, rows?: list[dict], results?: list[dict], deleted?: bool, ok?: bool, error?: str }` |
|
|
||||||
| Server → Client | `final` | `{ response: str }` |
|
|
||||||
| Server → Client | `ping` | `{}` |
|
|
||||||
|
|
||||||
**Actions:**
|
|
||||||
|
|
||||||
| `action` | What Electron does (Drizzle) | `tool_result` shape |
|
|
||||||
|---|---|---|
|
|
||||||
| `select` | `db.select().from(table).where(filters)` | `{ rows: [...] }` |
|
|
||||||
| `get` | `db.select().from(table).where(id=...).get()` | `{ row: {...} or null }` |
|
|
||||||
| `insert` | `db.insert(table).values({id: uuid(), ...data}).returning().get()` | `{ row: {...} }` |
|
|
||||||
| `update` | `db.update(table).set(updates).where(id=...).returning().get()` | `{ row: {...} }` |
|
|
||||||
| `delete` | `db.delete(table).where(id=...).run()` | `{ deleted: true }` |
|
|
||||||
| `vector_upsert` | LanceDB upsert with pre-computed vector | `{ ok: true }` |
|
|
||||||
| `vector_search` | LanceDB search by vector | `{ results: [{id, content, score}...] }` |
|
|
||||||
|
|
||||||
**Electron generates IDs + timestamps.** Backend tools never send `id` or `createdAt` in `insert` data — Electron adds `id: uuid()`, `createdAt: Date.now()`, `updatedAt: Date.now()`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## SQLite Schema Reference (Electron's local database)
|
|
||||||
|
|
||||||
Tools must use **camelCase** field names (Drizzle maps them to snake_case internally):
|
|
||||||
|
|
||||||
| Table | Columns |
|
|
||||||
|---|---|
|
|
||||||
| `tasks` | id, projectId, title, description, status (todo\|in_progress\|done), priority (high\|medium\|low), assignee (JSON array string), dueDate (ms), isAiSuggested (0\|1), isApproved (0\|1), createdAt (ms) |
|
|
||||||
| `projects` | id, clientId, name, status (active\|archived), aiSummary, createdAt (ms) |
|
|
||||||
| `timelines` | id, projectId (required), title, date (ms), isAiSuggested (0\|1), isApproved (0\|1), createdAt (ms) |
|
|
||||||
| `notes` | id, projectId, title, content (markdown), createdAt (ms), updatedAt (ms) |
|
|
||||||
| `taskComments` | id, taskId, author, content, createdAt (ms) |
|
|
||||||
| `clients` | id, parentId, name, industry, createdAt (ms) |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase B — Backend Changes
|
|
||||||
|
|
||||||
### Step B.1 — WS context + frame types
|
|
||||||
- [x] Create `app/core/ws_context.py` (~25 lines):
|
|
||||||
- `_client_executor: ContextVar[Callable]` — holds the async callback for the current WS session
|
|
||||||
- `async def execute_on_client(action, table=None, data=None, filters=None, vector=None, limit=None) -> dict`:
|
|
||||||
- Reads callback from ContextVar
|
|
||||||
- Builds `tool_call` payload: `{id: str(uuid4()), action, table, data, filters, vector, limit}` (omits None fields)
|
|
||||||
- Calls `await callback(payload)` — which sends the WS frame and waits for `tool_result`
|
|
||||||
- Returns the result dict
|
|
||||||
- `def set_client_executor(fn)` / `def clear_client_executor()` — ContextVar management
|
|
||||||
- [x] Add to `app/schemas.py`:
|
|
||||||
- `WsFrameType(str, Enum)`: `chat_request`, `text_chunk`, `tool_call`, `tool_result`, `final`, `ping`
|
|
||||||
- `WsToolCall(BaseModel)`: `type`, `id`, `action`, `table?`, `data?`, `filters?`, `vector?`, `limit?`
|
|
||||||
- `WsToolResult(BaseModel)`: `type`, `id`, `row?`, `rows?`, `results?`, `deleted?`, `ok?`, `error?`
|
|
||||||
- `WsTextChunk(BaseModel)`: `type`, `text`
|
|
||||||
- `WsFinal(BaseModel)`: `type`, `response`
|
|
||||||
- **Files:** `app/core/ws_context.py`, `app/schemas.py`
|
|
||||||
- **Outcome:** Any tool can `await execute_on_client(...)` to query/mutate the user's local DB.
|
|
||||||
|
|
||||||
### Step B.2 — Rewrite all 23 tools to use `execute_on_client()`
|
|
||||||
- [x] Each tool: same `@tool` decorator, same parameters, same docstring. Replace `return json.dumps({...})` body with:
|
|
||||||
1. Call `result = await execute_on_client(action=..., table=..., data/filters=...)`
|
|
||||||
2. Return human-readable string with confirmation + key data from `result`
|
|
||||||
|
|
||||||
- [x] **`app/agents/task_agent.py` (8 tools):**
|
|
||||||
- `list_tasks(project_id, status, search, order_by)`:
|
|
||||||
```python
|
|
||||||
result = await execute_on_client(action="select", table="tasks", filters={
|
|
||||||
"projectId": project_id or None,
|
|
||||||
"status": status or None,
|
|
||||||
"search": search or None,
|
|
||||||
"orderBy": order_by or None,
|
|
||||||
})
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return "No tasks found matching the given filters."
|
|
||||||
lines = [f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})" for r in rows]
|
|
||||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
|
||||||
```
|
|
||||||
- `create_task(title, ...)`:
|
|
||||||
```python
|
|
||||||
result = await execute_on_client(action="insert", table="tasks", data={
|
|
||||||
"title": title, "description": description or None, "status": status,
|
|
||||||
"priority": priority, "assignee": assignees, "dueDate": due_date or None,
|
|
||||||
"projectId": project_id or None, "isAiSuggested": is_ai_suggested, "isApproved": is_approved,
|
|
||||||
})
|
|
||||||
row = result["row"]
|
|
||||||
return f"Task created: '{row['title']}' (id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
|
||||||
```
|
|
||||||
- `update_task(task_id, ...)`: build updates dict (same logic as now) → `execute_on_client(action="update", table="tasks", data={"id": task_id, "updates": updates})` → return "Task updated: {title}"
|
|
||||||
- `delete_task(task_id)`: `execute_on_client(action="delete", table="tasks", data={"id": task_id})` → return "Task deleted"
|
|
||||||
- `list_tasks_due_today()`: calculate today's start/end ms → `execute_on_client(action="select", table="tasks", filters={"dueDateFrom": start, "dueDateTo": end})` → format + return
|
|
||||||
- `list_task_comments(task_id)`: `execute_on_client(action="select", table="taskComments", filters={"taskId": task_id})` → format + return
|
|
||||||
- `add_task_comment(task_id, author, content)`: `execute_on_client(action="insert", table="taskComments", data={...})` → return confirmation
|
|
||||||
- `delete_task_comment(comment_id)`: `execute_on_client(action="delete", table="taskComments", data={"id": comment_id})` → return confirmation
|
|
||||||
|
|
||||||
- [x] **`app/agents/project_agent.py` (6 tools):**
|
|
||||||
- `list_projects(client_id, include_archived)`: `execute_on_client(action="select", table="projects", filters={clientId, includeArchived})` → format + return
|
|
||||||
- `list_all_projects()`: `execute_on_client(action="select", table="projects")` → format + return
|
|
||||||
- `get_project(project_id)`: `execute_on_client(action="get", table="projects", data={"id": project_id})` → return project details or "not found"
|
|
||||||
- `create_project(name, client_id)`: `execute_on_client(action="insert", table="projects", data={name, clientId})` → return confirmation + id
|
|
||||||
- `update_project(project_id, ...)`: build updates → `execute_on_client(action="update", ...)` → return confirmation
|
|
||||||
- `delete_project(project_id)`: `execute_on_client(action="delete", ...)` → return confirmation
|
|
||||||
|
|
||||||
- [x] **`app/agents/timeline_agent.py` (4 tools):**
|
|
||||||
- `list_timelines(project_id)`: `execute_on_client(action="select", table="timelines", filters={projectId})` → format + return
|
|
||||||
- `create_timeline(project_id, title, date, ...)`: `execute_on_client(action="insert", table="timelines", data={...})` → return confirmation + id
|
|
||||||
- `update_timeline(timeline_id, ...)`: build updates → `execute_on_client(action="update", ...)` → return confirmation
|
|
||||||
- `delete_timeline(timeline_id)`: `execute_on_client(action="delete", ...)` → return confirmation
|
|
||||||
|
|
||||||
- [x] **`app/agents/note_agent.py` (5 tools):**
|
|
||||||
- `list_notes(project_id)`: `execute_on_client(action="select", table="notes", filters={projectId})` → format + return
|
|
||||||
- `get_note(note_id)`: `execute_on_client(action="get", table="notes", data={"id": note_id})` → return full content or "not found"
|
|
||||||
- `create_note(title, content, project_id)`: `execute_on_client(action="insert", table="notes", data={...})` → then `execute_on_client(action="vector_upsert", data={id, projectId, content}, vector=await embed(content))` → return confirmation
|
|
||||||
- `update_note(note_id, ...)`: build updates → `execute_on_client(action="update", ...)` → then vector_upsert for updated content → return confirmation
|
|
||||||
- `delete_note(note_id)`: `execute_on_client(action="delete", ...)` → return confirmation
|
|
||||||
|
|
||||||
- **Files:** `app/agents/task_agent.py`, `app/agents/project_agent.py`, `app/agents/timeline_agent.py`, `app/agents/note_agent.py`
|
|
||||||
- **Outcome:** All 23 tools query real user data via WS. LLM sees actual rows, not action descriptors.
|
|
||||||
|
|
||||||
### Step B.3 — Bidirectional WebSocket handler
|
|
||||||
- [x] Refactor `app/api/routes/chat.py` WS endpoint:
|
|
||||||
- After auth + accept + receive `chat_request`:
|
|
||||||
1. Create `execute_on_client` callback closure capturing the websocket:
|
|
||||||
```python
|
|
||||||
pending_calls: dict[str, asyncio.Future] = {}
|
|
||||||
|
|
||||||
async def on_client_result(frame: dict):
|
|
||||||
"""Called when a tool_result frame arrives from Electron."""
|
|
||||||
fut = pending_calls.pop(frame["id"], None)
|
|
||||||
if fut and not fut.done():
|
|
||||||
fut.set_result(frame)
|
|
||||||
|
|
||||||
async def execute_callback(payload: dict) -> dict:
|
|
||||||
"""Send tool_call to Electron, wait for tool_result."""
|
|
||||||
call_id = payload["id"]
|
|
||||||
fut = asyncio.get_event_loop().create_future()
|
|
||||||
pending_calls[call_id] = fut
|
|
||||||
await websocket.send_text(json.dumps({"type": "tool_call", **payload}))
|
|
||||||
return await asyncio.wait_for(fut, timeout=30.0)
|
|
||||||
```
|
|
||||||
2. Set `client_executor` ContextVar with `execute_callback`
|
|
||||||
3. Run orchestrator in a task — it calls agents, agents call tools, tools call `execute_on_client()` which goes through the callback
|
|
||||||
4. In parallel, run a message receive loop that dispatches incoming frames:
|
|
||||||
- `tool_result` → `on_client_result(frame)`
|
|
||||||
- `ping` → ignore
|
|
||||||
5. Orchestrator yields `text_chunk` frames → send to client
|
|
||||||
6. Send `final` frame when done
|
|
||||||
7. Clear ContextVar
|
|
||||||
- Keep heartbeat ping every 30s
|
|
||||||
- 30s timeout on `tool_result` — if Electron doesn't respond, future raises `TimeoutError`, tool returns error string to LLM
|
|
||||||
- **Files:** `app/api/routes/chat.py`
|
|
||||||
- **Outcome:** Full bidirectional WS. Tool calls and text streaming happen concurrently on the same connection.
|
|
||||||
|
|
||||||
### Step B.4 — `_tool_loop` — no changes needed
|
|
||||||
- [x] Verify `app/core/agent_registry.py` works unchanged:
|
|
||||||
- `_tool_loop` calls `tool_fn.ainvoke(args)` → tool awaits `execute_on_client()` (WS round-trip) → returns string → `ToolMessage(content=string)` → LLM sees real data
|
|
||||||
- The async WS round-trip happens inside each tool. `_tool_loop` just sees an awaited tool returning a string — same as before, different content.
|
|
||||||
- **No code changes.** Just verify + add a log line for tool execution times if desired.
|
|
||||||
|
|
||||||
### Step B.5 — Orchestrator cleanup
|
|
||||||
- [x] Update `app/core/orchestrator.py`:
|
|
||||||
- `orchestrate_stream()`: remove `"actions": []` from final frame. Final becomes: `{"done": true, "response": "..."}`
|
|
||||||
- No other changes — `classify_intent` → `call_agent` → chunk response → final frame
|
|
||||||
- **Files:** `app/core/orchestrator.py`
|
|
||||||
- **Outcome:** Clean final frame. No more action descriptors in the protocol.
|
|
||||||
|
|
||||||
### Step B.6 — Add `/vectors/embed` endpoint
|
|
||||||
- [x] Add to `app/api/routes/vectors.py`:
|
|
||||||
- `POST /api/v1/storage/vectors/embed`:
|
|
||||||
- Request: `{ text: str }`
|
|
||||||
- Response: `{ vector: list[float] }` (1536-dim from `text-embedding-3-small`)
|
|
||||||
- Auth required (JWT)
|
|
||||||
- Used by:
|
|
||||||
- Backend tools: `note_agent` calls this before `vector_upsert`
|
|
||||||
- Electron: `vectordb.ts` calls this for note embedding on create/update
|
|
||||||
- **Files:** `app/api/routes/vectors.py`
|
|
||||||
- **Outcome:** Single embedding endpoint. Both backend tools and Electron can generate vectors.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Verification
|
|
||||||
|
|
||||||
| What to test | How |
|
|
||||||
|---|---|
|
|
||||||
| **Read flow** | "List my tasks" → `list_tasks` → `tool_call{select, tasks}` → Electron returns rows → LLM describes real tasks |
|
|
||||||
| **Write flow** | "Create a task called Buy milk" → `create_task` → `tool_call{insert, tasks, data:{title:"Buy milk"}}` → Electron inserts + returns row → tool confirms with id |
|
|
||||||
| **Multi-tool** | "How many todo tasks do I have?" → `list_tasks(status=todo)` → LLM counts actual rows → "You have 3 todo tasks" |
|
|
||||||
| **Vector search** | "Find notes about deployment" → tool embeds → `tool_call{vector_search, vector:[...]}` → Electron searches LanceDB → returns matching notes |
|
|
||||||
| **Vector upsert** | "Create a note about..." → insert note → vector_upsert with embedding → both SQLite + LanceDB updated |
|
|
||||||
| **Tool timeout** | Disconnect Electron mid-conversation → 30s timeout → tool returns error → LLM handles gracefully |
|
|
||||||
| **Concurrent calls** | Agent calls 2 tools in sequence → each does WS round-trip → both succeed → LLM sees both results |
|
|
||||||
| **_tool_loop max iter** | Verify 5-iteration limit still works → after 5 tool calls, LLM forced to answer without tools |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Execution Notes
|
|
||||||
|
|
||||||
- **Phase 1 is the critical path.** Auth + backend client + drizzle executor + orchestrator refactor must land first.
|
|
||||||
- **Steps 1.1–1.4 are additive** — existing app keeps working until Step 1.5 swaps the orchestrator.
|
|
||||||
- **Step 2.1 is the point of no return** — after removing LangChain, there's no local AI fallback.
|
|
||||||
- **Phase B (backend changes) must land before Phase 1.3–1.5** — Electron needs the bidirectional WS to talk to.
|
|
||||||
- **Phase 3 and Phase 4 are independent** — can be parallelized after Phase 2.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 3 — Agent System: Config, Orchestration & Cloud Connectors
|
|
||||||
|
|
||||||
> **Objective:** Backend manages all agent configuration, scheduling, orchestration, and cloud data fetching. Two agent types: **Local Directory Agent** (backend triggers Electron to read files, then AI analyzes) and **Cloud Connector Agent** (backend fetches Gmail/Teams data directly, AI analyzes, pushes results to Electron via WS tool_call). All extracted items use existing WS tool infrastructure to insert into Electron's local DB with `is_ai_suggested=True`.
|
|
||||||
>
|
|
||||||
> **Electron Phase 3 plan:** `../adiuva/AI_REFACTOR_PLAN.md` Phase 3 section.
|
|
||||||
>
|
|
||||||
> **Electron UI status (2025):** Steps 3.6, 3.7, 3.8 of the Electron plan are ✅ complete. Agents are configured inside the Settings page (`/settings?section=agents`) — not a standalone route. The `JourneyDialog` (Step 3.8) is embedded inline in the Settings → Agents section. `LocalAgentConfigPanel` and `CloudAgentConfigPanel` (Step 3.7) are also inline. This affects the journey API contract (see Step 3.5 below).
|
|
||||||
|
|
||||||
### Architecture
|
|
||||||
|
|
||||||
```
|
|
||||||
Local Agent:
|
|
||||||
Scheduler/manual trigger ──► check device online ──► WS agent_run → Electron
|
|
||||||
──► Electron reads files ──► WS agent_data → Backend
|
|
||||||
──► Backend AI (prompt_template + file content) ──► WS tool_call(insert) → Electron
|
|
||||||
──► Electron persists with isAiSuggested=1
|
|
||||||
|
|
||||||
Cloud Agent:
|
|
||||||
Scheduler/manual trigger ──► Backend fetches Gmail/Teams (OAuth) ──► Backend AI analyzes
|
|
||||||
──► check device online ──► WS tool_call(insert) → Electron ──► Electron persists
|
|
||||||
```
|
|
||||||
|
|
||||||
**New WS frame types:**
|
|
||||||
|
|
||||||
| Direction | `type` | Payload |
|
|
||||||
|---|---|---|
|
|
||||||
| Server → Client | `agent_run` | `{ run_id, agent_id, config: { paths, file_extensions, prompt_template, data_types } }` |
|
|
||||||
| Client → Server | `agent_data` | `{ run_id, files: [{ path, name, content, metadata }] }` |
|
|
||||||
| Client → Server | `agent_complete` | `{ run_id, files_read, errors }` |
|
|
||||||
| Client → Server | `device_hello` | `{ device_id, agent_ids }` |
|
|
||||||
|
|
||||||
### Step 3.1 — Agent config tables
|
|
||||||
- [x] Add to `app/models.py`:
|
|
||||||
- **`LocalAgentConfig`**:
|
|
||||||
- `id` UUID PK
|
|
||||||
- `user_id` FK → users
|
|
||||||
- `device_id` str — identifies which Electron install this config belongs to
|
|
||||||
- `name` str
|
|
||||||
- `directory_paths` JSON — list of absolute paths on the device
|
|
||||||
- `data_types` JSON — which tables to extract to: `["tasks", "notes", "timelines", "projects"]`
|
|
||||||
- `prompt_template` text — user-configured via Chatbot Journey
|
|
||||||
- `file_extensions` JSON — e.g. `[".eml", ".txt", ".pdf", ".md"]`
|
|
||||||
- `schedule_cron` str — e.g. `"0 */6 * * *"` (every 6h)
|
|
||||||
- `enabled` bool (default True)
|
|
||||||
- `last_run_at` datetime nullable
|
|
||||||
- `created_at`, `updated_at` timestamps
|
|
||||||
- **`CloudAgentConfig`**:
|
|
||||||
- `id` UUID PK
|
|
||||||
- `user_id` FK → users
|
|
||||||
- `provider` str — enum: `gmail`, `teams`, `outlook`
|
|
||||||
- `name` str
|
|
||||||
- `data_types` JSON — same format as local
|
|
||||||
- `prompt_template` text
|
|
||||||
- `oauth_token_encrypted` text — Fernet-encrypted OAuth2 credentials
|
|
||||||
- `schedule_cron` str
|
|
||||||
- `enabled` bool (default True)
|
|
||||||
- `last_run_at` datetime nullable
|
|
||||||
- `filter_config` JSON — provider-specific: `{ labels: [], date_range: {from, to}, senders: [] }`
|
|
||||||
- `created_at`, `updated_at` timestamps
|
|
||||||
- **`AgentRunLog`**:
|
|
||||||
- `id` UUID PK
|
|
||||||
- `agent_id` str — references LocalAgentConfig.id or CloudAgentConfig.id
|
|
||||||
- `agent_type` str — `local` or `cloud`
|
|
||||||
- `user_id` FK → users
|
|
||||||
- `status` str — `running`, `success`, `error`, `partial`
|
|
||||||
- `items_processed` int (default 0)
|
|
||||||
- `items_created` int (default 0)
|
|
||||||
- `errors` JSON — list of error strings
|
|
||||||
- `started_at` datetime
|
|
||||||
- `completed_at` datetime nullable
|
|
||||||
- [x] Add Pydantic schemas to `app/schemas.py`:
|
|
||||||
- `LocalAgentConfigCreate`, `LocalAgentConfigUpdate`, `LocalAgentConfigResponse`
|
|
||||||
- `CloudAgentConfigCreate`, `CloudAgentConfigUpdate`, `CloudAgentConfigResponse`
|
|
||||||
- `AgentRunLogResponse`
|
|
||||||
- `AgentCatalogItem` — `{ type, name, description, config_schema }`
|
|
||||||
- `WsAgentRun`, `WsAgentData`, `WsAgentComplete`, `WsDeviceHello`
|
|
||||||
- [x] Generate Alembic migration
|
|
||||||
- **Files:** `app/models.py`, `app/schemas.py`, `alembic/versions/`
|
|
||||||
- **Outcome:** Agent config and run tracking tables in PostgreSQL.
|
|
||||||
|
|
||||||
### Step 3.2 — Agent CRUD API routes
|
|
||||||
- [x] Create `app/api/routes/agents.py`:
|
|
||||||
- `GET /api/v1/agents/catalog` — returns hardcoded agent type catalog:
|
|
||||||
- `local_directory`: "Watches local directories, extracts data from files using AI"
|
|
||||||
- `gmail`: "Scans Gmail inbox, extracts tasks/notes from emails"
|
|
||||||
- `teams`: "Monitors Teams messages, extracts action items"
|
|
||||||
- `outlook`: "Scans Outlook inbox, extracts tasks/notes"
|
|
||||||
- `GET /api/v1/agents/local` — list user's local agent configs
|
|
||||||
- `POST /api/v1/agents/local` — create local agent config
|
|
||||||
- Body: `{ name, device_id, directory_paths, data_types, prompt_template, file_extensions, schedule_cron }`
|
|
||||||
- Tier check: count enabled agents ≤ `batch_active` limit
|
|
||||||
- `PUT /api/v1/agents/local/{id}` — update config (ownership check)
|
|
||||||
- `DELETE /api/v1/agents/local/{id}` — delete config + associated run logs
|
|
||||||
- `GET /api/v1/agents/cloud` — list user's cloud agent configs
|
|
||||||
- `POST /api/v1/agents/cloud` — create cloud connector config
|
|
||||||
- Body: `{ provider, name, data_types, prompt_template, oauth_token_encrypted, schedule_cron, filter_config }`
|
|
||||||
- Tier check: same `batch_active` limit (local + cloud count together)
|
|
||||||
- `PUT /api/v1/agents/cloud/{id}` — update config
|
|
||||||
- `DELETE /api/v1/agents/cloud/{id}` — delete config + run logs
|
|
||||||
- `GET /api/v1/agents/runs` — query params: `agent_id`, `page`, `limit` → paginated run logs
|
|
||||||
- `POST /api/v1/agents/{id}/run` — manual trigger (dispatches to agent runner)
|
|
||||||
- All routes require JWT auth; ownership enforced on all mutations
|
|
||||||
- [x] Register router in `app/main.py`
|
|
||||||
- **Files:** `app/api/routes/agents.py`, `app/main.py`
|
|
||||||
- **Outcome:** Full CRUD for agent configs with tier-gated creation limits.
|
|
||||||
|
|
||||||
### Step 3.3 — Device WS endpoint
|
|
||||||
- [x] Create `app/api/routes/device_ws.py`:
|
|
||||||
- `WebSocket /api/v1/ws/device?token=<jwt>` — persistent connection from Electron
|
|
||||||
- On connect:
|
|
||||||
- Authenticate JWT
|
|
||||||
- Receive `device_hello` frame → extract `device_id`, `agent_ids`
|
|
||||||
- Store connection in `DeviceConnectionManager` (in-memory dict: `user_id → { ws, device_id }`)
|
|
||||||
- Check for overdue agent runs → trigger them immediately
|
|
||||||
- Message loop:
|
|
||||||
- `agent_data` → route to active agent run handler
|
|
||||||
- `agent_complete` → finalize agent run
|
|
||||||
- `tool_result` → route to pending tool call (same pattern as chat WS)
|
|
||||||
- `pong` → heartbeat ack
|
|
||||||
- On disconnect:
|
|
||||||
- Remove from `DeviceConnectionManager`
|
|
||||||
- Mark any in-progress agent runs as `error` with "device disconnected"
|
|
||||||
- Heartbeat: send `ping` every 30s, disconnect if no `pong` within 10s
|
|
||||||
- [x] Create `app/core/device_manager.py`:
|
|
||||||
- `DeviceConnectionManager` (singleton):
|
|
||||||
- `register(user_id, device_id, ws)` — stores active connection
|
|
||||||
- `unregister(user_id)` — removes connection
|
|
||||||
- `get_ws(user_id) -> WebSocket | None` — returns active WS if device is online
|
|
||||||
- `is_online(user_id, device_id=None) -> bool` — optionally checks specific device
|
|
||||||
- `send_frame(user_id, frame: dict)` — sends JSON frame to device
|
|
||||||
- **Files:** `app/api/routes/device_ws.py`, `app/core/device_manager.py`, `app/main.py`
|
|
||||||
- **Outcome:** Backend maintains persistent WS connections to Electron devices for agent triggers.
|
|
||||||
|
|
||||||
### Step 3.4 — Agent run orchestrator
|
|
||||||
- [x] Create `app/core/agent_runner.py`:
|
|
||||||
- `async run_local_agent(user_id, config: LocalAgentConfig, device_mgr: DeviceConnectionManager)`:
|
|
||||||
1. Check device is online with matching `device_id` → abort if offline
|
|
||||||
2. Create `AgentRunLog` with `status=running`
|
|
||||||
3. Send `WsAgentRun` frame to Electron with config (paths, extensions, prompt)
|
|
||||||
4. Await `WsAgentData` frames — collect file contents
|
|
||||||
5. Await `WsAgentComplete` frame — Electron signals done reading
|
|
||||||
6. For each file: call LLM with `prompt_template` + file content → extract structured items
|
|
||||||
7. For each extracted item: send `WsToolCall(insert, table, data)` to Electron → await `WsToolResult`
|
|
||||||
- All inserts include `is_ai_suggested=True, is_approved=False`
|
|
||||||
8. Update `AgentRunLog`: `status=success`, `items_processed`, `items_created`
|
|
||||||
- `async run_cloud_agent(user_id, config: CloudAgentConfig, device_mgr: DeviceConnectionManager)`:
|
|
||||||
1. Check device is online → abort if offline (results must push to Electron)
|
|
||||||
2. Create `AgentRunLog` with `status=running`
|
|
||||||
3. Decrypt OAuth credentials from `config.oauth_token_encrypted`
|
|
||||||
4. Fetch data from cloud provider (Step 3.6):
|
|
||||||
- Gmail: `google-api-python-client` + `filter_config` label/date filters
|
|
||||||
- Teams: `msgraph-sdk` + channel/date filters
|
|
||||||
- Outlook: `msgraph-sdk` + folder/date filters
|
|
||||||
5. For each item: call LLM with `prompt_template` + email/message content → extract structured items
|
|
||||||
6. For each extracted item: send `WsToolCall(insert)` to Electron → await `WsToolResult`
|
|
||||||
7. Update `AgentRunLog`
|
|
||||||
- `async trigger_pending_runs(user_id, device_id, device_mgr)`:
|
|
||||||
- Called when Electron connects (after `device_hello`)
|
|
||||||
- Queries all enabled agent configs where `last_run_at + schedule_interval < now()`
|
|
||||||
- For local agents: only triggers if `config.device_id == device_id`
|
|
||||||
- For cloud agents: triggers regardless of device (any connected device can receive results)
|
|
||||||
- Executes runs sequentially (one at a time to avoid overwhelming the WS)
|
|
||||||
- Error handling: on any failure, update `AgentRunLog` with `status=error` + error details
|
|
||||||
- [x] Wire `POST /agents/{id}/run` endpoint to dispatch background task via `asyncio.create_task()`
|
|
||||||
- [x] Replace `_trigger_pending_runs_stub` in `device_ws.py` with real `trigger_pending_runs` call
|
|
||||||
- [x] Add `croniter>=3.0.0` to `requirements.txt`
|
|
||||||
- [x] 23 unit + integration tests covering all code paths
|
|
||||||
- **Files:** `app/core/agent_runner.py`, `app/api/routes/agents.py`, `app/api/routes/device_ws.py`, `requirements.txt`, `tests/test_agent_runner.py`
|
|
||||||
- **Outcome:** Backend drives all agent execution — both local (via WS file request) and cloud (direct API calls — stub until Step 3.6).
|
|
||||||
|
|
||||||
### Step 3.5 — Chatbot Journey endpoint
|
|
||||||
- [x] Create `app/api/routes/agent_setup.py`:
|
|
||||||
- `POST /api/v1/agents/journey/start`:
|
|
||||||
- Body: `{ agent_type: "local"|"cloud", agent_id: str | None }`
|
|
||||||
- `agent_type`: which kind of agent this journey configures.
|
|
||||||
- `agent_id`: optional — if provided, the session is pre-seeded with the existing agent's `prompt_template` so the user can refine it. If absent, fresh journey.
|
|
||||||
- **No `data_types` field** — data types are determined through the conversation itself, not sent upfront.
|
|
||||||
- Creates a journey session (in-memory or Redis-backed)
|
|
||||||
- Returns first AI message: contextual question based on agent type
|
|
||||||
- Local: "What kind of files are in the directories you want to monitor? (emails, documents, logs, etc.)"
|
|
||||||
- Cloud: "What kind of emails/messages should I look for? (client communications, invoices, meeting notes, etc.)"
|
|
||||||
- Response: `{ session_id, message, done: false }`
|
|
||||||
- **Electron note:** `proxyPost` auto-converts camelCase keys to snake_case. Electron sends `{ agentType, agentId }` → backend receives `{ agent_type, agent_id }`.
|
|
||||||
- `POST /api/v1/agents/journey/message`:
|
|
||||||
- Body: `{ session_id, message }`
|
|
||||||
- AI processes user's answer, asks follow-up questions (max 5 turns)
|
|
||||||
- System prompt: "You are configuring a data extraction agent for a freelancer. Ask about file format, what data to extract (tasks, notes, timelines), naming conventions, priority rules, and any special mapping. After 3-5 questions, generate a detailed prompt_template."
|
|
||||||
- When AI determines enough context: `{ session_id, message: "Here's your configuration...", done: true, prompt_template: "..." }`
|
|
||||||
- The `prompt_template` is a structured instruction for the extraction LLM (e.g. "Extract tasks from email. Subject becomes task title. If body contains 'urgent' or 'ASAP', set priority to 'high'. Extract due dates if mentioned.")
|
|
||||||
- **Electron note:** `toCamelCase` converts the response → Electron reads `promptTemplate` from the final message and auto-fills the agent config panel. User clicks "Save & apply" which calls `agent.local.update` / `agent.cloud.update` tRPC mutation.
|
|
||||||
- **Files:** `app/api/routes/agent_setup.py`, `app/main.py`
|
|
||||||
- **Outcome:** Users configure AI prompts through guided conversation. Journey can refine an existing config when `agent_id` is provided. ✅
|
|
||||||
|
|
||||||
### Step 3.6 — Cloud provider integrations
|
|
||||||
- [x] Create `app/integrations/gmail.py`:
|
|
||||||
- `GmailClient`:
|
|
||||||
- `__init__(oauth_token)` — initializes Google API client
|
|
||||||
- `async fetch_messages(filter_config, since: datetime) -> list[EmailMessage]`
|
|
||||||
- `EmailMessage`: `{ id, subject, sender, body_text, date, labels }`
|
|
||||||
- Handles token refresh via Google OAuth2 refresh flow
|
|
||||||
- Respects `filter_config.labels`, `filter_config.date_range`, `filter_config.senders`
|
|
||||||
- [x] Create `app/integrations/ms_graph.py`:
|
|
||||||
- `MSGraphClient`:
|
|
||||||
- `__init__(oauth_token)` — initializes MS Graph client
|
|
||||||
- `async fetch_emails(filter_config, since: datetime) -> list[EmailMessage]` (Outlook)
|
|
||||||
- `async fetch_messages(filter_config, since: datetime) -> list[ChatMessage]` (Teams)
|
|
||||||
- `ChatMessage`: `{ id, content, sender, channel, date }`
|
|
||||||
- Handles token refresh via MSAL
|
|
||||||
- [x] Create `app/integrations/__init__.py` — factory: `get_provider(provider_name) -> GmailClient | MSGraphClient`
|
|
||||||
- **Dependencies:** `google-api-python-client`, `google-auth-oauthlib`, `msgraph-sdk`, `msal`
|
|
||||||
- **Files:** `app/integrations/gmail.py`, `app/integrations/ms_graph.py`, `app/integrations/__init__.py`
|
|
||||||
- **Outcome:** Backend can fetch emails/messages from Gmail, Outlook, and Teams.
|
|
||||||
|
|
||||||
### Step 3.7 — Agent scheduler
|
|
||||||
- [ ] Create `app/core/agent_scheduler.py`:
|
|
||||||
- Uses `APScheduler` (or simple asyncio loop) to check agent schedules
|
|
||||||
- Every 60s: query enabled agents where `last_run_at + cron_interval < now()`
|
|
||||||
- For each due agent:
|
|
||||||
- Check if user's device is online via `DeviceConnectionManager`
|
|
||||||
- If online: dispatch to `agent_runner`
|
|
||||||
- If offline: skip (will trigger on next `device_hello`)
|
|
||||||
- Locks: use PostgreSQL advisory locks to prevent duplicate runs in multi-instance deployments
|
|
||||||
- [ ] Integrate with FastAPI lifespan (start scheduler on app startup, shutdown gracefully)
|
|
||||||
- **Dependencies:** `apscheduler>=4.0`
|
|
||||||
- **Files:** `app/core/agent_scheduler.py`, `app/main.py`
|
|
||||||
- **Outcome:** Agents run automatically on their configured schedules.
|
|
||||||
|
|
||||||
### Step 3.8 — OAuth flow endpoints
|
|
||||||
- [ ] Create `app/api/routes/oauth.py`:
|
|
||||||
- `GET /api/v1/oauth/{provider}/authorize` — returns OAuth authorization URL
|
|
||||||
- Gmail: Google OAuth2 with `gmail.readonly` scope
|
|
||||||
- Outlook/Teams: MS identity platform with `Mail.Read`, `ChannelMessage.Read.All` scopes
|
|
||||||
- `GET /api/v1/oauth/{provider}/callback` — handles OAuth redirect
|
|
||||||
- Exchanges auth code for access + refresh tokens
|
|
||||||
- Encrypts tokens with Fernet (server-side key from settings)
|
|
||||||
- Returns encrypted token blob for storage in `CloudAgentConfig.oauth_token_encrypted`
|
|
||||||
- `POST /api/v1/oauth/{provider}/refresh` — refresh expired OAuth token
|
|
||||||
- **Files:** `app/api/routes/oauth.py`, `app/main.py`
|
|
||||||
- **Outcome:** Users can connect Gmail/Teams/Outlook accounts securely.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Phase 3 — Verification
|
|
||||||
|
|
||||||
| # | Scenario | Expected |
|
|
||||||
|---|---|---|
|
|
||||||
| 1 | **Agent CRUD** | Create/read/update/delete local and cloud configs; tier limits enforced (free=2, pro=10) |
|
|
||||||
| 2 | **WS device connect** | Electron connects → `device_hello` → backend stores connection → triggers overdue runs |
|
|
||||||
| 3 | **Local agent run** | Backend sends `agent_run` → Electron reads files → `agent_data` → backend AI extracts → `tool_call(insert)` → Electron persists with `isAiSuggested=1` |
|
|
||||||
| 4 | **Cloud agent run** | Backend fetches Gmail → AI extracts tasks → `tool_call(insert)` → Electron persists |
|
|
||||||
| 5 | **Device binding** | Local agent config with `device_id=A` only triggers when device A is connected |
|
|
||||||
| 6 | **Chatbot Journey** | Start journey → 3-5 Q&A turns → produces valid `prompt_template` |
|
|
||||||
| 7 | **Schedule** | Agent with `schedule_cron="0 */6 * * *"` runs every 6h when device is online |
|
|
||||||
| 8 | **Offline resilience** | Device offline → runs skipped → device reconnects → overdue runs trigger immediately |
|
|
||||||
| 9 | **OAuth flow** | Gmail authorize → callback → token encrypted → stored in config → fetch emails works |
|
|
||||||
|
|
||||||
### Phase 3 — New Dependencies
|
|
||||||
|
|
||||||
| Package | Purpose |
|
|
||||||
|---|---|
|
|
||||||
| `google-api-python-client` | Gmail API access |
|
|
||||||
| `google-auth-oauthlib` | Gmail OAuth2 flow |
|
|
||||||
| `msgraph-sdk` | Outlook + Teams API access |
|
|
||||||
| `msal` | MS identity platform auth |
|
|
||||||
| `apscheduler>=4.0` | Agent scheduling |
|
|
||||||
| `cryptography` (Fernet) | OAuth token encryption at rest |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ~~Phase 5 — Shared Memory~~ (SUPERSEDED)
|
|
||||||
|
|
||||||
> **This phase has been fully replaced by `V3_MIGRATION_PLAN.md`.**
|
|
||||||
>
|
|
||||||
> - Chat WS fix → V3 Step 5 (Unified WS Handler — single multiplexed socket)
|
|
||||||
> - Agent memory → V3 Steps 6–7 (Cloud-side MemGPT-style memory in PostgreSQL + pgvector, encrypted at rest with per-user Fernet key)
|
|
||||||
>
|
|
||||||
> The on-device KV approach (Electron SQLite `agent_memory` table) is no longer the target architecture.
|
|
||||||
> See `V3_MIGRATION_PLAN.md` for the current plan.
|
|
||||||
572
BACKEND_PLAN.md
572
BACKEND_PLAN.md
@@ -1,572 +0,0 @@
|
|||||||
# Backend Plan — Adiuva Cloud API
|
|
||||||
|
|
||||||
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
|
|
||||||
>
|
|
||||||
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, E2E backup blob storage, cloud storage (encrypted blobs), cloud vector store, and plugin marketplace.
|
|
||||||
> The backend NEVER persists user data in plaintext. Cloud storage blobs are E2E encrypted before upload — the backend only verifies integrity, never decrypts.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── app/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── main.py # FastAPI entry + CORS + lifespan + router includes
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── agent_registry.py # Base classes + singleton registry
|
|
||||||
│ │ ├── orchestrator.py # LLM-based intent router
|
|
||||||
│ │ ├── execution_plan.py # Plan builder + cache
|
|
||||||
│ │ └── plugin_loader.py # Dynamic agent loading
|
|
||||||
│ ├── agents/ # Chat agents (proprietary logic + prompts)
|
|
||||||
│ │ ├── __init__.py # Auto-registers all agents
|
|
||||||
│ │ ├── task_agent.py
|
|
||||||
│ │ ├── calendar_agent.py
|
|
||||||
│ │ ├── email_agent.py
|
|
||||||
│ │ └── analytics_agent.py
|
|
||||||
│ ├── api/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── routes/
|
|
||||||
│ │ │ ├── __init__.py
|
|
||||||
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
|
|
||||||
│ │ │ ├── plans.py # GET /plans/playbook
|
|
||||||
│ │ │ ├── storage.py # CRUD cloud storage (E2E encrypted blobs)
|
|
||||||
│ │ │ ├── vectors.py # Upsert/search cloud vector store
|
|
||||||
│ │ │ ├── backup.py # PUT/GET /backup
|
|
||||||
│ │ │ ├── plugins.py # Plugin marketplace
|
|
||||||
│ │ │ ├── auth.py # Register/login/refresh
|
|
||||||
│ │ │ └── billing.py # Checkout/webhook/subscription
|
|
||||||
│ │ └── middleware/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── auth.py # JWT validation
|
|
||||||
│ │ ├── rate_limit.py # Tier-aware rate limiting
|
|
||||||
│ │ └── sanitizer.py # Strip prompt metadata from responses
|
|
||||||
│ ├── storage/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── blob_store.py # S3 for E2E encrypted blobs
|
|
||||||
│ │ ├── vector_store.py # Cloud vector store (Pinecone/Qdrant)
|
|
||||||
│ │ └── encryption.py # Integrity verification only — NO decryption
|
|
||||||
│ ├── marketplace/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── plugin_registry.py # Plugin catalog (metadata, versions, ratings)
|
|
||||||
│ │ ├── plugin_review.py # Review queue + approval workflow
|
|
||||||
│ │ └── revenue_share.py # 70/30 split tracking with Stripe Connect
|
|
||||||
│ ├── billing/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── stripe_service.py # Stripe checkout + webhooks
|
|
||||||
│ │ └── tier_manager.py # Feature matrix per tier
|
|
||||||
│ └── config/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ └── settings.py # Pydantic BaseSettings (env-based)
|
|
||||||
├── tests/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── conftest.py # Fixtures: test client, mock agents, mock LLM
|
|
||||||
│ ├── test_orchestrator.py
|
|
||||||
│ ├── test_agents.py
|
|
||||||
│ ├── test_auth.py
|
|
||||||
│ ├── test_backup.py
|
|
||||||
│ ├── test_storage.py
|
|
||||||
│ └── test_plugins.py
|
|
||||||
├── alembic/ # DB migrations (auth/billing/marketplace tables only)
|
|
||||||
│ ├── alembic.ini
|
|
||||||
│ └── versions/
|
|
||||||
├── requirements.txt
|
|
||||||
├── Dockerfile
|
|
||||||
├── docker-compose.yml # App + PostgreSQL + Redis (dev)
|
|
||||||
├── .env.example
|
|
||||||
└── README.md
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step-by-Step Implementation
|
|
||||||
|
|
||||||
### Step 1 — Project scaffolding ✅
|
|
||||||
- [x] Initialize repo with the directory structure above
|
|
||||||
- [x] Write `requirements.txt`:
|
|
||||||
```
|
|
||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
langchain>=0.3.0
|
|
||||||
langchain-openai>=0.3.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
python-jose[cryptography]>=3.3.0
|
|
||||||
stripe>=11.0.0
|
|
||||||
boto3>=1.35.0
|
|
||||||
slowapi>=0.1.9
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
asyncpg>=0.30.0
|
|
||||||
alembic>=1.14.0
|
|
||||||
bcrypt>=4.2.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
httpx>=0.28.0
|
|
||||||
websockets>=14.0
|
|
||||||
pytest>=8.0.0
|
|
||||||
pytest-asyncio>=0.24.0
|
|
||||||
```
|
|
||||||
- [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
|
|
||||||
- [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod), `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
|
||||||
- [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
|
|
||||||
- [x] Write `docker-compose.yml`: app, postgres:16, optional redis
|
|
||||||
- [x] Write `.env.example`
|
|
||||||
- **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes).
|
|
||||||
|
|
||||||
### Step 2 — Pydantic schemas (API contracts) ✅
|
|
||||||
- [x] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
|
|
||||||
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
|
|
||||||
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
|
|
||||||
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
|
|
||||||
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification', 'call_agent']`, `table: str | None`, `data: dict | None`, `agent: str | None`
|
|
||||||
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
|
|
||||||
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
|
|
||||||
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
|
|
||||||
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
|
|
||||||
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
|
|
||||||
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
|
||||||
- `StorageRecord`: `id: str`, `user_id: str`, `table: str`, `blob: bytes`, `checksum: str`, `created_at: int`, `updated_at: int` — blob is always E2E encrypted by client
|
|
||||||
- `StorageRecordCreate`: `table: str`, `blob: bytes`, `checksum: str`
|
|
||||||
- `StorageRecordUpdate`: `blob: bytes`, `checksum: str`
|
|
||||||
- `VectorUpsertRequest`: `vectors: list[VectorItem]`
|
|
||||||
- `VectorItem`: `id: str`, `blob: bytes`, `checksum: str` — vector + metadata encrypted by client
|
|
||||||
- `VectorSearchRequest`: `query_blob: bytes`, `top_k: int = 10`
|
|
||||||
- `VectorSearchResponse`: `results: list[VectorSearchResult]`
|
|
||||||
- `VectorSearchResult`: `id: str`, `score: float`, `blob: bytes`
|
|
||||||
- `PluginManifest`: `id: str`, `name: str`, `description: str`, `version: str`, `author: str`, `permissions: list[str]`, `category: str`, `price_cents: int = 0`
|
|
||||||
- `PluginListResponse`: `plugins: list[PluginManifest]`, `total: int`, `page: int`
|
|
||||||
- `PluginInstallRequest`: `plugin_id: str`
|
|
||||||
- **Outcome:** All request/response models defined and validated.
|
|
||||||
|
|
||||||
### Step 3 — Agent Registry + base classes ✅
|
|
||||||
- [x] `app/core/agent_registry.py`:
|
|
||||||
- `BaseAgent(ABC)`:
|
|
||||||
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
|
|
||||||
- Abstract `get_name() -> str`, `get_description() -> str`
|
|
||||||
- `ChatAgent(BaseAgent)`:
|
|
||||||
- Abstract `async handle(query: str, context: dict) -> str`
|
|
||||||
- Abstract `get_tools() -> list` (LangChain tool definitions)
|
|
||||||
- Concrete `_tool_loop(llm, messages, tools, max_iter=5) -> str` — shared tool-calling loop
|
|
||||||
- `AgentRegistry` (singleton):
|
|
||||||
- `_agents: dict[str, ChatAgent]`
|
|
||||||
- `register(agent_class)` — decorator pattern
|
|
||||||
- `get(name) -> ChatAgent`
|
|
||||||
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
|
|
||||||
- `async call_agent(name, query, context) -> str` — for inter-agent calls
|
|
||||||
- [x] Unit tests: register, get, list, call_agent with mock
|
|
||||||
- **Outcome:** Pluggable agent framework.
|
|
||||||
|
|
||||||
### Step 4 — Orchestrator ✅
|
|
||||||
- [x] `app/core/orchestrator.py`:
|
|
||||||
- `async classify_intent(message, context, registry) -> str`:
|
|
||||||
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
|
|
||||||
- Uses gpt-4o-mini via LangChain for low latency
|
|
||||||
- Falls back to `task_agent` if no clear match
|
|
||||||
- `async route_single(agent_name, message, context) -> ChatResponse`:
|
|
||||||
- Instantiates agent from registry
|
|
||||||
- Calls `agent.handle(message, context)`
|
|
||||||
- Returns response + any actions the agent produced
|
|
||||||
- `async route_pipeline(agent_names, message, context) -> ChatResponse`:
|
|
||||||
- Executes agents in sequence
|
|
||||||
- Each agent receives `{...context, previous_results: [...]}`
|
|
||||||
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
|
|
||||||
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
|
|
||||||
- Main entry point
|
|
||||||
- Context is transparent to orchestrator — data may originate from local or cloud storage on the client side
|
|
||||||
- Classifies intent
|
|
||||||
- If `execution_mode == 'direct'`: route + return response
|
|
||||||
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
|
|
||||||
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
|
|
||||||
- Same as orchestrate but yields tokens for WebSocket streaming
|
|
||||||
- [x] Integration tests with mocked LLM and mocked agents
|
|
||||||
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
|
|
||||||
|
|
||||||
### Step 5 — Execution Plan generator ✅
|
|
||||||
- [x] `app/core/execution_plan.py`:
|
|
||||||
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
|
|
||||||
- `ExecutionPlanBuilder`:
|
|
||||||
- `add_step(action, params) -> self`
|
|
||||||
- `add_llm_step(template_id, variables) -> self`
|
|
||||||
- `add_data_step(action, data_from_step) -> self`
|
|
||||||
- `build() -> ExecutionPlan` — validates step references
|
|
||||||
- `PlanCache`:
|
|
||||||
- In-memory LRU (maxsize=1000)
|
|
||||||
- `cache_plan(key, plan)`, `get_plan(key)`, `get_all_playbooks() -> list[ExecutionPlan]`
|
|
||||||
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
|
|
||||||
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
|
|
||||||
|
|
||||||
### Step 6 — Chat Agents ✅
|
|
||||||
- [x] `app/agents/task_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
- Tools (8): `list_tasks(project_id, status, search, order_by)`, `create_task(title, description, status, priority, assignees, due_date, project_id, is_ai_suggested, is_approved)`, `update_task(task_id, ...)`, `delete_task(task_id)`, `list_tasks_due_today()`, `list_task_comments(task_id)`, `add_task_comment(task_id, author, content)`, `delete_task_comment(comment_id)`
|
|
||||||
- status: `todo|in_progress|done`; priority: `high|medium|low`; assignees: JSON-encoded string; due_date: ms timestamp
|
|
||||||
- Accepts flexible context; sentinel `-1` for optional integer update fields
|
|
||||||
- [x] `app/agents/timeline_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages project timelines (milestones): list, create, update, delete"
|
|
||||||
- Tools (4): `list_timelines(project_id)`, `create_timeline(project_id, title, date, is_ai_suggested, is_approved)`, `update_timeline(timeline_id, ...)`, `delete_timeline(timeline_id)`
|
|
||||||
- `project_id` is required for create; date is a ms timestamp; supports AI-suggestion + approval workflow
|
|
||||||
- [x] `app/agents/project_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
- Tools (6): `list_projects(client_id, include_archived)`, `list_all_projects()`, `get_project(project_id)`, `create_project(name, client_id)`, `update_project(project_id, ...)`, `delete_project(project_id)`
|
|
||||||
- status: `active|archived`; prefers archive over deletion (docstring guard on delete)
|
|
||||||
- [x] `app/agents/note_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages notes: list, get, create, update, delete"
|
|
||||||
- Tools (5): `list_notes(project_id)`, `get_note(note_id)`, `create_note(title, content, project_id)`, `update_note(note_id, ...)`, `delete_note(note_id)`
|
|
||||||
- content is Markdown; `get_note` should be called before update to preserve existing content
|
|
||||||
- [x] `app/agents/__init__.py`: imports all four agent modules to trigger `@registry.register` decorators
|
|
||||||
- [x] Unit tests per agent with mocked LLM (registration, names, tool counts, handle(), direct tool invocation)
|
|
||||||
- **Outcome:** Four domain-specific agents matching the UI data model (Tasks, Timelines, Projects, Notes), all registered and tested.
|
|
||||||
|
|
||||||
### Step 7 — Storage Layer ✅
|
|
||||||
- [x] `app/storage/blob_store.py`:
|
|
||||||
- `BlobStore`: `async upload`, `async download`, `async delete` (idempotent), `async list_keys`
|
|
||||||
- Keys: `{user_id}/{table}/{record_id}` — backend never inspects blob content
|
|
||||||
- boto3 S3 with SSE-S3 at-rest encryption; client checksum stored in S3 object metadata
|
|
||||||
- [x] `app/storage/vector_store.py`:
|
|
||||||
- `VectorStore`: `async upsert`, `async search`, `async delete`
|
|
||||||
- Pinecone (default, `namespace=user_id`) or Qdrant (`user_id` payload filter) — runtime-configurable
|
|
||||||
- 32-dim SHA-256-derived float vector; blob stored as base64 in metadata/payload
|
|
||||||
- ANN on encrypted data: known accuracy trade-off, documented
|
|
||||||
- [x] `app/storage/encryption.py`:
|
|
||||||
- `verify_checksum(blob, checksum) -> bool` — SHA-256 + `hmac.compare_digest` (constant-time)
|
|
||||||
- `reject_if_tampered(blob, checksum)` — raises `HTTP 400` on mismatch
|
|
||||||
- Backend NEVER holds decryption keys
|
|
||||||
- [x] `app/schemas.py`: added `StorageRecord*`, `VectorItem`, `VectorUpsertRequest`, `VectorSearch*`, `Plugin*` schemas
|
|
||||||
- [x] `app/config/settings.py`: added `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
|
||||||
- [x] `requirements.txt`: added `moto[s3]`, `pinecone`, `qdrant-client`
|
|
||||||
- [x] 37 unit tests covering encryption, BlobStore (moto), VectorStore Pinecone, VectorStore Qdrant
|
|
||||||
- **Outcome:** Cloud storage layer that handles E2E encrypted blobs without ever accessing plaintext.
|
|
||||||
|
|
||||||
### Step 8 — API Routes ✅
|
|
||||||
|
|
||||||
#### 8a — Chat endpoint
|
|
||||||
- [x] `app/api/routes/chat.py`:
|
|
||||||
- `POST /api/v1/chat`:
|
|
||||||
- Request: `ChatRequest`
|
|
||||||
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
|
|
||||||
- Response: `ChatResponse` or `ExecutionPlan`
|
|
||||||
- `WebSocket /api/v1/chat/stream`:
|
|
||||||
- Client sends `ChatRequest` as first JSON frame
|
|
||||||
- Server yields token strings via `orchestrate_stream()`
|
|
||||||
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
|
|
||||||
- Heartbeat ping every 30s to keep connection alive
|
|
||||||
|
|
||||||
#### 8b — Plans endpoint
|
|
||||||
- [x] `app/api/routes/plans.py`:
|
|
||||||
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
|
|
||||||
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
|
|
||||||
|
|
||||||
#### 8c — Storage endpoint (cloud records)
|
|
||||||
- [x] `app/api/routes/storage.py`:
|
|
||||||
- `POST /api/v1/storage/records`: Create encrypted record
|
|
||||||
- Request: `StorageRecordCreate`
|
|
||||||
- Verifies checksum, stores blob in S3, inserts metadata row in PostgreSQL
|
|
||||||
- Response: `{id: str, created_at: int}`
|
|
||||||
- `GET /api/v1/storage/records`: List record metadata (no blobs)
|
|
||||||
- Query params: `table: str`, `page: int`, `limit: int`
|
|
||||||
- Response: `list[{id, table, checksum, created_at, updated_at}]`
|
|
||||||
- `GET /api/v1/storage/records/{id}`: Download encrypted blob
|
|
||||||
- Response: blob bytes + `X-Checksum` header
|
|
||||||
- `PUT /api/v1/storage/records/{id}`: Update encrypted blob
|
|
||||||
- Request: `StorageRecordUpdate`
|
|
||||||
- `DELETE /api/v1/storage/records/{id}`: Delete record + S3 blob
|
|
||||||
- All routes enforce tier cloud_storage_gb quota via `TierManager.check_quota(user_id)`
|
|
||||||
|
|
||||||
#### 8d — Vectors endpoint (cloud vector store)
|
|
||||||
- [x] `app/api/routes/vectors.py`:
|
|
||||||
- `POST /api/v1/storage/vectors/upsert`:
|
|
||||||
- Request: `VectorUpsertRequest`
|
|
||||||
- Verifies checksums, delegates to `VectorStore.upsert()`
|
|
||||||
- Response: `{upserted: int}`
|
|
||||||
- `POST /api/v1/storage/vectors/search`:
|
|
||||||
- Request: `VectorSearchRequest`
|
|
||||||
- Delegates to `VectorStore.search()`
|
|
||||||
- Response: `VectorSearchResponse`
|
|
||||||
- `DELETE /api/v1/storage/vectors`:
|
|
||||||
- Request: `{ids: list[str]}`
|
|
||||||
|
|
||||||
#### 8e — Backup endpoint
|
|
||||||
- [x] `app/api/routes/backup.py`:
|
|
||||||
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
|
|
||||||
- Free: 0 (no backup)
|
|
||||||
- Pro: 5 GB
|
|
||||||
- Power: 25 GB
|
|
||||||
- Team: unlimited
|
|
||||||
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
|
|
||||||
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
|
|
||||||
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
|
|
||||||
|
|
||||||
#### 8f — Plugins endpoint
|
|
||||||
- [x] `app/api/routes/plugins.py`:
|
|
||||||
- `GET /api/v1/plugins`:
|
|
||||||
- Query params: `category: str | None`, `q: str | None`, `page: int`, `sort: Literal['rating', 'installs', 'newest']`
|
|
||||||
- Response: `PluginListResponse`
|
|
||||||
- Available from Power tier and above
|
|
||||||
- `GET /api/v1/plugins/{id}`:
|
|
||||||
- Response: `PluginManifest` + ratings + install count
|
|
||||||
- `POST /api/v1/plugins/{id}/install`:
|
|
||||||
- Request: `PluginInstallRequest`
|
|
||||||
- Records installation for the user (billing tracking, analytics)
|
|
||||||
- If plugin is paid: triggers Stripe Connect charge + revenue split (70% developer, 30% platform)
|
|
||||||
- Response: `{ok: true, download_url: str}` — signed S3 URL for plugin package
|
|
||||||
- `DELETE /api/v1/plugins/{id}/install`:
|
|
||||||
- Unregisters installation
|
|
||||||
|
|
||||||
#### 8g — Auth endpoint
|
|
||||||
- [x] `app/api/routes/auth.py`:
|
|
||||||
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
|
|
||||||
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
|
|
||||||
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
|
|
||||||
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
|
|
||||||
|
|
||||||
#### 8h — Billing endpoint
|
|
||||||
- [x] `app/api/routes/billing.py`:
|
|
||||||
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
|
|
||||||
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
|
|
||||||
- `GET /api/v1/billing/subscription`: Returns current subscription info
|
|
||||||
- `DELETE /api/v1/billing/subscription`: Cancels subscription
|
|
||||||
|
|
||||||
- **Outcome:** Complete REST + WebSocket API covering orchestration, storage, vectors, backup, marketplace.
|
|
||||||
|
|
||||||
### Step 9 — Middleware
|
|
||||||
|
|
||||||
#### 9a — Auth middleware
|
|
||||||
- [x] `app/api/middleware/auth.py`:
|
|
||||||
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
|
||||||
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
|
||||||
- Raises `401` on invalid/expired token
|
|
||||||
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
|
||||||
|
|
||||||
#### 9b — Rate limiter
|
|
||||||
- [x] `app/api/middleware/rate_limit.py`:
|
|
||||||
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
|
||||||
- Tier-based limits:
|
|
||||||
- Free: 20 req/min
|
|
||||||
- Pro: 60 req/min
|
|
||||||
- Power: 120 req/min
|
|
||||||
- Team: 200 req/seat/min
|
|
||||||
- Custom 429 response with `Retry-After` header
|
|
||||||
|
|
||||||
#### 9c — Sanitizer
|
|
||||||
- [x] `app/api/middleware/sanitizer.py`:
|
|
||||||
- Response middleware that scans response bodies
|
|
||||||
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
|
||||||
- Pattern-based detection + exact match against known prompt fingerprints
|
|
||||||
- Logs sanitization events for monitoring
|
|
||||||
|
|
||||||
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
|
||||||
|
|
||||||
### Step 10 — Plugin Marketplace ✅
|
|
||||||
- [x] `app/marketplace/plugin_registry.py`:
|
|
||||||
- `PluginRegistry`:
|
|
||||||
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
|
|
||||||
- `async get_plugin(plugin_id) -> PluginManifest | None`
|
|
||||||
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
|
|
||||||
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
|
|
||||||
- `async reject_plugin(plugin_id, reason: str) -> None`
|
|
||||||
- [x] `app/marketplace/plugin_review.py`:
|
|
||||||
- `ReviewQueue`:
|
|
||||||
- `async get_pending() -> list[dict]`
|
|
||||||
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
|
|
||||||
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
|
|
||||||
- [x] `app/marketplace/revenue_share.py`:
|
|
||||||
- `RevenueShare`:
|
|
||||||
- `async record_install(plugin_id, user_id, amount_cents) -> None`
|
|
||||||
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
|
|
||||||
- `async get_earnings(developer_id, period) -> dict`
|
|
||||||
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
|
|
||||||
|
|
||||||
### Step 11 — Billing & Tier management ✅
|
|
||||||
- [x] `app/billing/stripe_service.py`:
|
|
||||||
- `create_checkout_session(user_id, tier) -> str`
|
|
||||||
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
|
||||||
- `get_subscription(user_id) -> dict | None`
|
|
||||||
- `cancel_subscription(user_id) -> None`
|
|
||||||
- [x] `app/billing/tier_manager.py`:
|
|
||||||
- `TierManager`:
|
|
||||||
- Feature matrix:
|
|
||||||
```python
|
|
||||||
FEATURES = {
|
|
||||||
'free': {
|
|
||||||
'agents': 3,
|
|
||||||
'batch_active': 2,
|
|
||||||
'cloud_storage_gb': 0,
|
|
||||||
'backup_gb': 0,
|
|
||||||
'providers': 1,
|
|
||||||
'batch_builder': False,
|
|
||||||
'plugin_marketplace': False,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'pro': {
|
|
||||||
'agents': -1, # unlimited
|
|
||||||
'batch_active': 10,
|
|
||||||
'cloud_storage_gb': 5,
|
|
||||||
'backup_gb': 5,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': False,
|
|
||||||
'plugin_marketplace': False,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'power': {
|
|
||||||
'agents': -1,
|
|
||||||
'batch_active': -1, # unlimited
|
|
||||||
'cloud_storage_gb': 25,
|
|
||||||
'backup_gb': 25,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': True,
|
|
||||||
'plugin_marketplace': True,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'team': {
|
|
||||||
'agents': -1,
|
|
||||||
'batch_active': -1,
|
|
||||||
'cloud_storage_gb': -1,
|
|
||||||
'backup_gb': -1,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': True,
|
|
||||||
'plugin_marketplace': True,
|
|
||||||
'sso': True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
```
|
|
||||||
- `get_tier(user_id) -> BillingTier`
|
|
||||||
- `check_feature(user_id, feature) -> bool`
|
|
||||||
- `get_rate_limit(tier) -> int`
|
|
||||||
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
|
|
||||||
- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons
|
|
||||||
- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService`
|
|
||||||
- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota`
|
|
||||||
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
|
||||||
|
|
||||||
### Step 12 — Database (auth/billing/marketplace only)
|
|
||||||
- [x] PostgreSQL schema via Alembic:
|
|
||||||
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
|
|
||||||
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
|
|
||||||
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
|
|
||||||
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
|
|
||||||
- `storage_records`: `id UUID PK`, `user_id FK`, `table_name VARCHAR`, `s3_key`, `checksum`, `size_bytes`, `created_at`, `updated_at` — metadata only, no plaintext
|
|
||||||
- `plugins`: `id UUID PK`, `name`, `description`, `version`, `author_id FK`, `category`, `status` (pending_review/approved/rejected), `price_cents`, `s3_package_key`, `install_count`, `avg_rating`, `created_at`
|
|
||||||
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
|
|
||||||
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
|
|
||||||
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
|
|
||||||
- [x] Initial Alembic migration
|
|
||||||
- [x] SQLAlchemy models in `app/models.py`
|
|
||||||
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
|
|
||||||
|
|
||||||
### Step 13 — Testing & deployment ✅
|
|
||||||
- [x] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone
|
|
||||||
- [x] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
|
|
||||||
- [x] `tests/test_agents.py`: each agent with mocked tools
|
|
||||||
- [x] `tests/test_auth.py`: register → login → access protected → refresh → expired token
|
|
||||||
- [x] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
|
|
||||||
- [x] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement
|
|
||||||
- [x] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked)
|
|
||||||
- [x] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
|
|
||||||
- [x] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
|
|
||||||
- **Outcome:** Fully tested, deployable backend.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## API Contract Summary
|
|
||||||
|
|
||||||
| Method | Endpoint | Auth | Request | Response |
|
|
||||||
|--------|----------|------|---------|----------|
|
|
||||||
| POST | `/api/v1/auth/register` | No | `{email, password}` | `AuthTokens` |
|
|
||||||
| POST | `/api/v1/auth/login` | No | `{email, password}` | `AuthTokens` |
|
|
||||||
| POST | `/api/v1/auth/refresh` | No | `{refresh_token}` | `AuthTokens` |
|
|
||||||
| GET | `/api/v1/auth/me` | JWT | — | `UserProfile` |
|
|
||||||
| POST | `/api/v1/chat` | JWT | `ChatRequest` | `ChatResponse \| ExecutionPlan` |
|
|
||||||
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
|
|
||||||
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
|
|
||||||
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
|
|
||||||
| POST | `/api/v1/storage/records` | JWT | `StorageRecordCreate` | `{id, created_at}` |
|
|
||||||
| GET | `/api/v1/storage/records` | JWT | `?table&page&limit` | `RecordMeta[]` |
|
|
||||||
| GET | `/api/v1/storage/records/:id` | JWT | — | Binary blob |
|
|
||||||
| PUT | `/api/v1/storage/records/:id` | JWT | `StorageRecordUpdate` | `{ok: true}` |
|
|
||||||
| DELETE | `/api/v1/storage/records/:id` | JWT | — | `{ok: true}` |
|
|
||||||
| POST | `/api/v1/storage/vectors/upsert` | JWT | `VectorUpsertRequest` | `{upserted: int}` |
|
|
||||||
| POST | `/api/v1/storage/vectors/search` | JWT | `VectorSearchRequest` | `VectorSearchResponse` |
|
|
||||||
| DELETE | `/api/v1/storage/vectors` | JWT | `{ids: list[str]}` | `{ok: true}` |
|
|
||||||
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/backup` | JWT | — | Binary blob |
|
|
||||||
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
|
|
||||||
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/plugins` | JWT | `?category&q&page&sort` | `PluginListResponse` |
|
|
||||||
| GET | `/api/v1/plugins/:id` | JWT | — | `PluginManifest` + stats |
|
|
||||||
| POST | `/api/v1/plugins/:id/install` | JWT | `PluginInstallRequest` | `{ok, download_url}` |
|
|
||||||
| DELETE | `/api/v1/plugins/:id/install` | JWT | — | `{ok: true}` |
|
|
||||||
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
|
|
||||||
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
|
|
||||||
| DELETE | `/api/v1/billing/subscription` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/health` | No | — | `{status, version}` |
|
|
||||||
| GET | `/api/v1/agents/catalog` | JWT | — | `AgentCatalogItem[]` |
|
|
||||||
| GET | `/api/v1/agents/local` | JWT | — | `LocalAgentConfigResponse[]` |
|
|
||||||
| POST | `/api/v1/agents/local` | JWT | `LocalAgentConfigCreate` | `LocalAgentConfigResponse` |
|
|
||||||
| PUT | `/api/v1/agents/local/{id}` | JWT | `LocalAgentConfigUpdate` | `LocalAgentConfigResponse` |
|
|
||||||
| DELETE | `/api/v1/agents/local/{id}` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/agents/cloud` | JWT | — | `CloudAgentConfigResponse[]` |
|
|
||||||
| POST | `/api/v1/agents/cloud` | JWT | `CloudAgentConfigCreate` | `CloudAgentConfigResponse` |
|
|
||||||
| PUT | `/api/v1/agents/cloud/{id}` | JWT | `CloudAgentConfigUpdate` | `CloudAgentConfigResponse` |
|
|
||||||
| DELETE | `/api/v1/agents/cloud/{id}` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/agents/runs` | JWT | `?agent_id&page&limit` | `AgentRunLogResponse[]` |
|
|
||||||
| POST | `/api/v1/agents/{id}/run` | JWT | — | `{ok: true, run_id}` |
|
|
||||||
| POST | `/api/v1/agents/journey/start` | JWT | `{agent_type, data_types}` | `{session_id, message, done}` |
|
|
||||||
| POST | `/api/v1/agents/journey/message` | JWT | `{session_id, message}` | `{session_id, message, done, prompt_template?}` |
|
|
||||||
| GET | `/api/v1/oauth/{provider}/authorize` | JWT | — | `{authorization_url}` |
|
|
||||||
| GET | `/api/v1/oauth/{provider}/callback` | — | OAuth code | `{encrypted_token}` |
|
|
||||||
| WS | `/api/v1/ws/device` | JWT | `device_hello` (first frame) | Agent trigger + tool_call frames |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stack
|
|
||||||
|
|
||||||
| Layer | Technology |
|
|
||||||
|-------|-----------|
|
|
||||||
| Framework | FastAPI + Uvicorn |
|
|
||||||
| LLM | LangChain + langchain-openai |
|
|
||||||
| Auth | PyJWT + bcrypt + OAuth2 |
|
|
||||||
| Billing | stripe-python + Stripe Connect |
|
|
||||||
| Blob storage | boto3 (S3) |
|
|
||||||
| Vector store | Pinecone or Qdrant (configurable) |
|
|
||||||
| Database | PostgreSQL + SQLAlchemy + Alembic |
|
|
||||||
| Rate limiting | slowapi |
|
|
||||||
| Cloud integrations | google-api-python-client, msgraph-sdk, msal |
|
|
||||||
| Agent scheduling | APScheduler |
|
|
||||||
| Testing | pytest + pytest-asyncio + httpx + moto (S3 mock) |
|
|
||||||
| Deployment | Docker → fly.io / Railway / AWS ECS |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 3 — New Files
|
|
||||||
|
|
||||||
| File | Purpose |
|
|
||||||
|---|---|
|
|
||||||
| `app/models.py` | Add `LocalAgentConfig`, `CloudAgentConfig`, `AgentRunLog` models |
|
|
||||||
| `app/schemas.py` | Add agent config schemas + WS agent frame types |
|
|
||||||
| `app/api/routes/agents.py` | Agent CRUD endpoints (catalog, local, cloud, runs, manual trigger) |
|
|
||||||
| `app/api/routes/agent_setup.py` | Chatbot Journey endpoints (start + message) |
|
|
||||||
| `app/api/routes/device_ws.py` | Persistent device WS endpoint (`/api/v1/ws/device`) |
|
|
||||||
| `app/api/routes/oauth.py` | OAuth authorize/callback for Gmail, Teams, Outlook |
|
|
||||||
| `app/core/agent_runner.py` | Agent run orchestration — local (WS file request) + cloud (API fetch) |
|
|
||||||
| `app/core/device_manager.py` | `DeviceConnectionManager` — tracks active Electron WS connections |
|
|
||||||
| `app/core/agent_scheduler.py` | Periodic scheduler for agent cron triggers |
|
|
||||||
| `app/integrations/gmail.py` | Gmail API client (fetch messages with filters) |
|
|
||||||
| `app/integrations/ms_graph.py` | MS Graph client for Outlook emails + Teams messages |
|
|
||||||
| `app/integrations/__init__.py` | Provider factory |
|
|
||||||
|
|
||||||
> **Full Phase 3 step-by-step plan:** See `AI_REFACTOR_PLAN.md` Phase 3 section.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Development Rules
|
|
||||||
|
|
||||||
1. **NEVER persist user data in plaintext.** The DB stores only auth, billing, storage metadata, and marketplace data. User context arrives in requests and is discarded. Cloud blobs are E2E encrypted client-side — backend only stores opaque bytes.
|
|
||||||
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. In plan mode, `prompt_template` fields are reference IDs only.
|
|
||||||
3. **NEVER decrypt user blobs.** `app/storage/encryption.py` only verifies checksums. No decryption key ever reaches the backend.
|
|
||||||
4. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
|
|
||||||
5. **Type hints everywhere.** All functions have full type annotations.
|
|
||||||
6. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
|
|
||||||
7. **Structured logging.** JSON logs with request ID correlation.
|
|
||||||
8. **Tier gates are enforced server-side.** Never trust client-reported tier. Always fetch from DB via `TierManager.get_tier(user_id)`.
|
|
||||||
9. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.
|
|
||||||
@@ -1,353 +0,0 @@
|
|||||||
# V3 Migration Plan — Multi-Agent AI Productivity App
|
|
||||||
|
|
||||||
> Incremental migration from current architecture to v3.
|
|
||||||
> Each step is self-contained, testable, and backwards-compatible.
|
|
||||||
> No BYOK — server manages all LLM keys.
|
|
||||||
> Memory encryption: server-side per-user Fernet key (Option A).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## General Rules
|
|
||||||
|
|
||||||
**Code Cleanup**: As you implement each step, remove any code that becomes unused or obsolete. This includes:
|
|
||||||
- Old functions/methods that are superseded by new ones
|
|
||||||
- Deprecated imports or modules
|
|
||||||
- Dead code paths
|
|
||||||
- Old test files no longer needed
|
|
||||||
|
|
||||||
This keeps the codebase clean and prevents confusion. When removing code, note it in the commit message if significant.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Decisions Log
|
|
||||||
|
|
||||||
| Topic | Decision |
|
|
||||||
|---|---|
|
|
||||||
| WS topology | Single multiplexed socket (merge chat into device WS) |
|
|
||||||
| LLM keys | Server-managed only, no user key passthrough |
|
|
||||||
| Memory encryption | Per-user server-generated Fernet key, encrypted at rest, decrypted in-memory |
|
|
||||||
| device_manager | Already multi-user correct (keyed by user_id), no structural change |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 1 — WS Frame Protocol (schemas.py)
|
|
||||||
|
|
||||||
**Goal**: Define the v3 frame vocabulary so all subsequent steps can import it.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/schemas.py` — Add to `WsFrameType` enum:
|
|
||||||
- `home_request`, `floating_request`
|
|
||||||
- `stream_start`, `stream_text`, `stream_block`, `stream_end`
|
|
||||||
- `floating_domain`
|
|
||||||
- `data_request`, `data_response`, `mutation`
|
|
||||||
- Add Pydantic models:
|
|
||||||
- `WsHomeRequest(type, message, conversation_history?)`
|
|
||||||
- `WsFloatingRequest(type, message, scope: {type, id?})`
|
|
||||||
- `WsStreamStart(type, request_id)`
|
|
||||||
- `WsStreamText(type, request_id, chunk)`
|
|
||||||
- `WsStreamBlock(type, request_id, block_type, data)`
|
|
||||||
- `WsStreamEnd(type, request_id, mutations?)`
|
|
||||||
- `WsFloatingDomain(type, request_id, domain)`
|
|
||||||
- Keep all existing frame types (backward compat).
|
|
||||||
|
|
||||||
**Files touched**: `app/schemas.py`
|
|
||||||
|
|
||||||
**Test**: Unit test that validates each new model serializes/deserializes correctly.
|
|
||||||
```
|
|
||||||
pytest tests/test_schemas_v3.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 1 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-1: add v3 ws frame protocol (schemas.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 2 — Agent Streaming + Tool Result Capture (agent_registry.py, agents/)
|
|
||||||
|
|
||||||
**Goal**: Agents can stream LLM tokens and expose structured tool results.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/core/agent_registry.py`:
|
|
||||||
- Add `_tool_loop_stream()` to `ChatAgent` — same logic as `_tool_loop()` but the **final** LLM call (when no more tool calls) uses `llm.astream()` and yields tokens.
|
|
||||||
- Add `self.tool_results: list[dict]` attribute to `ChatAgent.__init__()`.
|
|
||||||
- In both `_tool_loop` and `_tool_loop_stream`, capture raw `execute_on_client` results when tools run (store in `self.tool_results`).
|
|
||||||
- `app/agents/*.py` — Each agent's tools already return text summaries. No change to tools. The raw data capture happens at the `_tool_loop` level by intercepting `ToolMessage` content that comes from `execute_on_client`.
|
|
||||||
|
|
||||||
**Files touched**: `app/core/agent_registry.py`
|
|
||||||
|
|
||||||
**Test**: Unit test with mocked LLM that verifies `_tool_loop_stream()` yields tokens and `agent.tool_results` contains structured data after a tool call.
|
|
||||||
```
|
|
||||||
pytest tests/test_agent_streaming.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 2 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-2: add agent streaming and tool result capture (agent_registry.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 3 — Router Refactor (orchestrator.py)
|
|
||||||
|
|
||||||
**Goal**: Orchestrator returns agent name alongside execution, supports streaming.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/core/orchestrator.py`:
|
|
||||||
- Add `orchestrate_v3(user_id, message, context, mode)` that:
|
|
||||||
1. Calls `classify_intent()` (unchanged) -> `agent_name`
|
|
||||||
2. Instantiates agent via registry
|
|
||||||
3. Returns `(agent_name, agent_instance)` — caller drives execution
|
|
||||||
- Add `orchestrate_v3_stream(user_id, message, context)` -> `AsyncGenerator` that:
|
|
||||||
1. Calls `classify_intent()` -> `agent_name`
|
|
||||||
2. Calls `agent.handle_stream()` (uses `_tool_loop_stream`)
|
|
||||||
3. Yields `(agent_name, token)` tuples — first yield includes agent name for domain detection
|
|
||||||
- Keep `orchestrate()` and `orchestrate_stream()` unchanged (backward compat for POST /chat).
|
|
||||||
|
|
||||||
**Files touched**: `app/core/orchestrator.py`
|
|
||||||
|
|
||||||
**Test**: Unit test with mocked LLM and mocked registry that verifies `orchestrate_v3_stream` yields `(agent_name, token)` pairs.
|
|
||||||
```
|
|
||||||
pytest tests/test_orchestrator_v3.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 3 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-3: add router refactor with streaming support (orchestrator.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 4 — Output Formatting Layer (NEW: output_formatter.py)
|
|
||||||
|
|
||||||
**Goal**: Home and Floating responses diverge at this layer only.
|
|
||||||
|
|
||||||
### Block Types (from Electron app components)
|
|
||||||
|
|
||||||
The LLM outputs a JSON block stream. Each block has a `type` field that maps to
|
|
||||||
an Electron renderer component. The server validates and forwards these blocks.
|
|
||||||
|
|
||||||
**Text block** — streamed immediately, word-by-word:
|
|
||||||
```json
|
|
||||||
{ "type": "text", "content": "Here's your task summary..." }
|
|
||||||
```
|
|
||||||
|
|
||||||
**Chart blocks** — buffered until complete, validated, sent as `stream_block`.
|
|
||||||
Chart types match shadcn/ui Recharts wrappers used in the Electron app:
|
|
||||||
```json
|
|
||||||
{ "type": "chart", "chartType": "<type>", "title": "...", "data": [...], "config": {...} }
|
|
||||||
```
|
|
||||||
Supported `chartType` values:
|
|
||||||
- `area` — Area chart (shadcn AreaChart)
|
|
||||||
- `bar` — Bar chart (shadcn BarChart)
|
|
||||||
- `line` — Line chart (shadcn LineChart)
|
|
||||||
- `pie` — Pie chart (shadcn PieChart)
|
|
||||||
- `radar` — Radar chart (shadcn RadarChart)
|
|
||||||
- `radial` — Radial/gauge chart (shadcn RadialChart)
|
|
||||||
|
|
||||||
`data` is an array of objects with keys matching the chart's dataKey config.
|
|
||||||
`config` follows the shadcn ChartConfig format: `{ [dataKey]: { label, color } }`.
|
|
||||||
|
|
||||||
**Entity blocks** — server serializes from `agent.tool_results` (not LLM-generated data):
|
|
||||||
```json
|
|
||||||
{ "type": "entity_ref", "entity": "task" }
|
|
||||||
```
|
|
||||||
The server resolves this by looking up the structured data from the agent's
|
|
||||||
tool call results and emitting a `stream_block` with the full entity data.
|
|
||||||
|
|
||||||
Supported entity types (matching Electron component types):
|
|
||||||
- `task` — TaskRow component (`TaskItem`: id, title, status, priority, assignee, dueDate, projectId, ...)
|
|
||||||
- `project` — Project card (id, name, clientId, status)
|
|
||||||
- `note` — Note card (id, title, createdAt, projectId)
|
|
||||||
- `timeline` — Timeline card (GanttTimeline: id, title, date, projectId, isAiSuggested, isApproved)
|
|
||||||
|
|
||||||
**Table block** — buffered, validated:
|
|
||||||
```json
|
|
||||||
{ "type": "table", "headers": ["Col1", "Col2"], "rows": [["val1", "val2"]] }
|
|
||||||
```
|
|
||||||
|
|
||||||
**Timeline block** — buffered, validated (renders via GanttChart component):
|
|
||||||
```json
|
|
||||||
{ "type": "timeline", "timelines": [{ "id": "...", "title": "...", "date": 1234567890 }] }
|
|
||||||
```
|
|
||||||
|
|
||||||
### Changes
|
|
||||||
|
|
||||||
- `app/core/output_formatter.py` (new file):
|
|
||||||
- `HomeFormatter`:
|
|
||||||
- Receives token stream from orchestrator
|
|
||||||
- Accumulates tokens into a JSON-aware buffer
|
|
||||||
- Detects block boundaries by `type` field:
|
|
||||||
- `text` -> yields `WsStreamText` immediately (streams content word-by-word)
|
|
||||||
- `chart` -> buffers until JSON complete, validates `chartType` against allowed set, yields `WsStreamBlock`
|
|
||||||
- `entity_ref` -> looks up data from `agent.tool_results`, serializes full entity, yields `WsStreamBlock`
|
|
||||||
- `table` -> buffers, validates headers/rows structure, yields `WsStreamBlock`
|
|
||||||
- `timeline` -> buffers, validates timeline objects, yields `WsStreamBlock`
|
|
||||||
- Invalid blocks are logged and skipped (never crash the stream)
|
|
||||||
- `FloatingFormatter`:
|
|
||||||
- Receives `agent_name` from orchestrator
|
|
||||||
- Maps agent name to domain (deterministic, by code — no LLM):
|
|
||||||
- `task_agent` -> `"tasks"`
|
|
||||||
- `timeline_agent` -> `"timelines"`
|
|
||||||
- `note_agent` -> `"notes"`
|
|
||||||
- `project_agent` -> `"projects"`
|
|
||||||
- Yields `WsFloatingDomain` immediately
|
|
||||||
- Then yields `WsStreamText` for all tokens (text-only, no blocks)
|
|
||||||
|
|
||||||
**Files touched**: `app/core/output_formatter.py` (new)
|
|
||||||
|
|
||||||
**Test**: Unit test that feeds a mock token stream through each formatter and asserts correct frame output sequence.
|
|
||||||
```
|
|
||||||
pytest tests/test_output_formatter.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 4 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-4: add output formatting layer (output_formatter.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 5 — Unified WS Handler (device_ws.py, chat.py, main.py)
|
|
||||||
|
|
||||||
**Goal**: Single multiplexed WebSocket handles device frames + Home/Floating chat.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/api/routes/device_ws.py`:
|
|
||||||
- Extend `_message_loop` dispatch to handle `home_request` and `floating_request`:
|
|
||||||
- On `home_request`: set `ws_context` executor, call `orchestrate_v3_stream`, pipe through `HomeFormatter`, send frames back on same socket.
|
|
||||||
- On `floating_request`: same, but pipe through `FloatingFormatter`.
|
|
||||||
- Wrap both in try/finally to clear `ws_context`.
|
|
||||||
- Each request gets a `request_id` (UUID) for frame correlation.
|
|
||||||
- Concurrent requests from same client are supported (each runs as an async task).
|
|
||||||
- `app/api/routes/chat.py`:
|
|
||||||
- Remove `chat_stream` WS endpoint and any related helper functions that were only used by it.
|
|
||||||
- Keep `POST /chat` endpoint unchanged (REST fallback).
|
|
||||||
- Clean up any unused imports.
|
|
||||||
- `app/main.py`:
|
|
||||||
- No change needed (device_ws router already registered).
|
|
||||||
|
|
||||||
**Files touched**: `app/api/routes/device_ws.py`, `app/api/routes/chat.py`, `app/main.py`
|
|
||||||
|
|
||||||
**Test**: Integration test with a WebSocket test client that:
|
|
||||||
1. Connects to `/api/v1/ws/device`
|
|
||||||
2. Sends `device_hello`
|
|
||||||
3. Sends `home_request` -> receives `stream_start`, `stream_text`*, `stream_end`
|
|
||||||
4. Sends `floating_request` -> receives `floating_domain`, `stream_text`*, `stream_end`
|
|
||||||
5. Verifies `tool_call`/`tool_result` round-trip still works during chat
|
|
||||||
```
|
|
||||||
pytest tests/test_ws_unified.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 5 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-5: unify ws handler (device_ws.py, chat.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 6 — Memory Models + Migration (models.py, alembic)
|
|
||||||
|
|
||||||
**Goal**: Database tables for 4-tier memory, with per-user encryption key.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/models.py`:
|
|
||||||
- Add `encryption_key` column to `User` model (Fernet key, generated on registration).
|
|
||||||
- Add `MemoryCore` model: `id, user_id, key, value_encrypted, updated_at`
|
|
||||||
- Add `MemoryAssociative` model: `id, user_id, content_encrypted, embedding (Vector(1536)), entity_type, entity_id, updated_at`
|
|
||||||
- Add `MemoryEpisodic` model: `id, user_id, summary_encrypted, session_id, created_at`
|
|
||||||
- Add `MemoryProactive` model: `id, user_id, pattern_encrypted, confidence, source, created_at`
|
|
||||||
- `alembic/versions/` — New migration adding the 4 memory tables + user encryption_key column.
|
|
||||||
- `app/api/routes/auth.py` — On user registration, generate and store a Fernet key.
|
|
||||||
|
|
||||||
**Files touched**: `app/models.py`, `alembic/versions/xxx_add_memory_tables.py`, `app/api/routes/auth.py`
|
|
||||||
|
|
||||||
**Test**: Run migration up/down, verify tables exist with correct columns.
|
|
||||||
```
|
|
||||||
alembic upgrade head && alembic downgrade -1 && alembic upgrade head
|
|
||||||
pytest tests/test_memory_models.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 6 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-6: add memory models and migration (models.py, alembic)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 7 — Memory Middleware (NEW: memory_middleware.py)
|
|
||||||
|
|
||||||
**Goal**: Enrich every Router call with memory context, store interactions after.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/core/memory_middleware.py` (new file):
|
|
||||||
- `MemoryMiddleware` class with:
|
|
||||||
- `enrich_context(user_id, message) -> dict` (pre-LLM):
|
|
||||||
1. Load core memory (user prefs) — always injected
|
|
||||||
2. Embed `message`, search `MemoryAssociative` via pgvector — top-k relevant
|
|
||||||
3. Fetch recent `MemoryEpisodic` entries — last N sessions
|
|
||||||
4. Fetch active `MemoryProactive` patterns — above confidence threshold
|
|
||||||
5. Return merged context dict
|
|
||||||
- `store_episode(user_id, session_id, message, response)` (post-LLM):
|
|
||||||
1. Summarize interaction (short LLM call or heuristic)
|
|
||||||
2. Encrypt and store in `MemoryEpisodic`
|
|
||||||
3. Embed interaction, encrypt and upsert in `MemoryAssociative`
|
|
||||||
- `update_core(user_id, key, value)` — explicit preference update
|
|
||||||
- All read/write operations encrypt/decrypt using the user's Fernet key from `User.encryption_key`
|
|
||||||
- `app/api/routes/device_ws.py` — Update `home_request` and `floating_request` handlers:
|
|
||||||
- Before orchestrator: `enriched = await memory.enrich_context(user_id, message)`
|
|
||||||
- After response complete: `await memory.store_episode(user_id, ...)`
|
|
||||||
|
|
||||||
**Files touched**: `app/core/memory_middleware.py` (new), `app/api/routes/device_ws.py`
|
|
||||||
|
|
||||||
**Test**: Unit test with seeded memory rows that verifies:
|
|
||||||
1. `enrich_context` returns core prefs + associative matches + episodic summaries
|
|
||||||
2. `store_episode` creates encrypted rows that can be decrypted with the user's key
|
|
||||||
3. End-to-end WS test: send `home_request`, verify memory enrichment is passed to orchestrator
|
|
||||||
```
|
|
||||||
pytest tests/test_memory_middleware.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 7 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-7: add memory middleware (memory_middleware.py, device_ws.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Summary
|
|
||||||
|
|
||||||
| Step | Component | Effort | Depends On |
|
|
||||||
|------|-----------|--------|------------|
|
|
||||||
| 1 | WS Frame Protocol | Low | — |
|
|
||||||
| 2 | Agent Streaming | Medium | Step 1 |
|
|
||||||
| 3 | Router Refactor | Medium | Step 2 |
|
|
||||||
| 4 | Output Formatter | High | Steps 1, 3 |
|
|
||||||
| 5 | Unified WS Handler | High | Steps 1–4 |
|
|
||||||
| 6 | Memory Models | Medium | — |
|
|
||||||
| 7 | Memory Middleware | High | Steps 5, 6 |
|
|
||||||
|
|
||||||
Steps 1–5 form the streaming pipeline. Steps 6–7 form the memory system.
|
|
||||||
Step 6 can run in parallel with Steps 2–4 (no dependencies).
|
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
"""Deprecate backend agent config tables.
|
||||||
|
|
||||||
|
The Electron client is now the source of truth for agent configuration
|
||||||
|
(directory, extract targets, batch interval, custom prompt). Backend keeps
|
||||||
|
billing checks and trigger/run logs only.
|
||||||
|
|
||||||
|
Revision ID: 9a1f2d0b6c7e
|
||||||
|
Revises: 818478c251dc
|
||||||
|
Create Date: 2026-03-16
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "9a1f2d0b6c7e"
|
||||||
|
down_revision: Union[str, None] = "818478c251dc"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = sa.inspect(bind)
|
||||||
|
existing = set(inspector.get_table_names())
|
||||||
|
|
||||||
|
if "cloud_agent_configs" in existing:
|
||||||
|
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||||
|
op.drop_table("cloud_agent_configs")
|
||||||
|
|
||||||
|
if "local_agent_configs" in existing:
|
||||||
|
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||||
|
op.drop_table("local_agent_configs")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"local_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("device_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"cloud_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"provider",
|
||||||
|
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||||
|
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Import all agent modules to trigger @registry.register decorators."""
|
"""Expose tool modules used by deep orchestrator-worker graphs."""
|
||||||
|
|
||||||
from app.agents import timeline_agent, note_agent, project_agent, task_agent
|
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
__all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]
|
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
|
|||||||
85
app/agents/filesystem_agent.py
Normal file
85
app/agents/filesystem_agent.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||||
|
|
||||||
|
These tools delegate to the Electron client via ``execute_on_client()`` using
|
||||||
|
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
|
||||||
|
handles actual disk I/O and responds with ``tool_result`` frames.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_directory(path: str) -> str:
|
||||||
|
"""List files and folders in a local directory on the user's device.
|
||||||
|
|
||||||
|
Returns a formatted listing of entries with name, type (file/directory),
|
||||||
|
and full path.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
return f"Directory '{path}' is empty or does not exist."
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
entry_type = entry.get("type", "unknown")
|
||||||
|
entry_name = entry.get("name", "")
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||||
|
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_file_content(path: str) -> str:
|
||||||
|
"""Read the text content of a local file on the user's device.
|
||||||
|
|
||||||
|
Returns the file content as a string. Large files may be truncated
|
||||||
|
by the Electron client.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
content: str = result.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return f"File '{path}' is empty or could not be read."
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_file_metadata(path: str) -> str:
|
||||||
|
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||||
|
|
||||||
|
Returns a formatted summary of the file's metadata.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_file_metadata",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
size = result.get("size", "unknown")
|
||||||
|
created = result.get("createdAt", "unknown")
|
||||||
|
modified = result.get("modifiedAt", "unknown")
|
||||||
|
extension = result.get("extension", "unknown")
|
||||||
|
name = result.get("name", path)
|
||||||
|
return (
|
||||||
|
f"File: {name}\n"
|
||||||
|
f" Extension: {extension}\n"
|
||||||
|
f" Size: {size} bytes\n"
|
||||||
|
f" Created: {created}\n"
|
||||||
|
f" Modified: {modified}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FILESYSTEM_TOOLS: list[Any] = [
|
||||||
|
list_directory,
|
||||||
|
read_file_content,
|
||||||
|
get_file_metadata,
|
||||||
|
]
|
||||||
@@ -2,17 +2,23 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
from app.core.llm import embed
|
||||||
from app.core.llm import embed, get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
NOTE_SYSTEM_PROMPT = (
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
"and delete Markdown notes in their workspace.\n\n"
|
||||||
"Rules:\n"
|
"Rules:\n"
|
||||||
@@ -22,6 +28,7 @@ _SYSTEM_PROMPT = (
|
|||||||
" before appending or replacing sections\n"
|
" before appending or replacing sections\n"
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
" - list_notes without project_id returns all notes; scope with project_id\n"
|
||||||
" when the user is working within a specific project\n"
|
" when the user is working within a specific project\n"
|
||||||
|
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
" - Do not fabricate note content — reflect what the user provides or what\n"
|
||||||
" is already in the note (retrieved via get_note)."
|
" is already in the note (retrieved via get_note)."
|
||||||
)
|
)
|
||||||
@@ -30,10 +37,11 @@ _SYSTEM_PROMPT = (
|
|||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="notes",
|
table="notes",
|
||||||
filters={"projectId": project_id or None},
|
filters={"projectId": normalized_project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -122,23 +130,10 @@ async def delete_note(note_id: str) -> str:
|
|||||||
return f"Note {note_id} deleted."
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
NOTE_TOOLS: list[Any] = [
|
||||||
class NoteAgent(ChatAgent):
|
list_notes,
|
||||||
def get_name(self) -> str:
|
get_note,
|
||||||
return "note_agent"
|
create_note,
|
||||||
|
update_note,
|
||||||
def get_description(self) -> str:
|
delete_note,
|
||||||
return "Manages notes: list, get, create, update, delete"
|
]
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [list_notes, get_note, create_note, update_note, delete_note]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -2,17 +2,13 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
PROJECT_SYSTEM_PROMPT = (
|
||||||
"You are a project management assistant. You help users create, find,\n"
|
"You are a project management assistant. You help users create, find,\n"
|
||||||
"update, and archive projects in their workspace.\n\n"
|
"update, and archive projects in their workspace.\n\n"
|
||||||
"Rules:\n"
|
"Rules:\n"
|
||||||
@@ -137,30 +133,11 @@ async def delete_project(project_id: str) -> str:
|
|||||||
return f"Project {project_id} permanently deleted."
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
PROJECT_TOOLS: list[Any] = [
|
||||||
class ProjectAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "project_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [
|
|
||||||
list_projects,
|
list_projects,
|
||||||
list_all_projects,
|
list_all_projects,
|
||||||
get_project,
|
get_project,
|
||||||
create_project,
|
create_project,
|
||||||
update_project,
|
update_project,
|
||||||
delete_project,
|
delete_project,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -2,18 +2,23 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
TASK_SYSTEM_PROMPT = (
|
||||||
"You are a task management assistant for a project workspace.\n"
|
"You are a task management assistant for a project workspace.\n"
|
||||||
"You create, update, list, and track tasks and their comments.\n\n"
|
"You create, update, list, and track tasks and their comments.\n\n"
|
||||||
"Rules:\n"
|
"Rules:\n"
|
||||||
@@ -43,11 +48,12 @@ async def list_tasks(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
filters={
|
filters={
|
||||||
"projectId": project_id or None,
|
"projectId": normalized_project_id or None,
|
||||||
"status": status or None,
|
"status": status or None,
|
||||||
"search": search or None,
|
"search": search or None,
|
||||||
"orderBy": order_by or None,
|
"orderBy": order_by or None,
|
||||||
@@ -209,8 +215,12 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|||||||
table="taskComments",
|
table="taskComments",
|
||||||
data={"taskId": task_id, "author": author, "content": content},
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result.get("row", {})
|
||||||
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
|
row_author = row.get("author", author)
|
||||||
|
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
|
||||||
|
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||||
|
row_comment_id = row.get("id", "unknown")
|
||||||
|
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -223,16 +233,7 @@ async def delete_task_comment(comment_id: str) -> str:
|
|||||||
# ── Agent ─────────────────────────────────────────────────────────────
|
# ── Agent ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
TASK_TOOLS: list[Any] = [
|
||||||
class TaskAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "task_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [
|
|
||||||
list_tasks,
|
list_tasks,
|
||||||
create_task,
|
create_task,
|
||||||
update_task,
|
update_task,
|
||||||
@@ -241,14 +242,4 @@ class TaskAgent(ChatAgent):
|
|||||||
list_task_comments,
|
list_task_comments,
|
||||||
add_task_comment,
|
add_task_comment,
|
||||||
delete_task_comment,
|
delete_task_comment,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -2,21 +2,27 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
TIMELINE_SYSTEM_PROMPT = (
|
||||||
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
"track progress on a project — they are not calendar events.\n\n"
|
||||||
"Rules:\n"
|
"Rules:\n"
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
||||||
|
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
||||||
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
||||||
@@ -29,10 +35,11 @@ _SYSTEM_PROMPT = (
|
|||||||
@tool
|
@tool
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
"""List timelines. Provide project_id to scope to a specific project."""
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
filters={"projectId": project_id or None},
|
filters={"projectId": normalized_project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -106,23 +113,9 @@ async def delete_timeline(timeline_id: str) -> str:
|
|||||||
return f"Timeline {timeline_id} deleted."
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
TIMELINE_TOOLS: list[Any] = [
|
||||||
class TimelineAgent(ChatAgent):
|
list_timelines,
|
||||||
def get_name(self) -> str:
|
create_timeline,
|
||||||
return "timeline_agent"
|
update_timeline,
|
||||||
|
delete_timeline,
|
||||||
def get_description(self) -> str:
|
]
|
||||||
return "Manages project timelines (milestones): list, create, update, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [list_timelines, create_timeline, update_timeline, delete_timeline]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -55,12 +55,15 @@ async def get_current_user(
|
|||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
|
|
||||||
# Live tier lookup — subscription row is the authoritative source.
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
|
# In dev, fall back to 'power' (unlimited) so quota limits don't
|
||||||
|
# block local development when no Stripe subscription exists.
|
||||||
from app.models import Subscription, User # noqa: PLC0415
|
from app.models import Subscription, User # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str = result.scalar_one_or_none() or "free"
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
# Fetch name/surname from user row.
|
# Fetch name/surname from user row.
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
|
|||||||
@@ -1,54 +1,40 @@
|
|||||||
"""Chatbot Journey endpoints — guided conversation to build an agent prompt_template.
|
"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
Endpoints:
|
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
||||||
POST /agents/journey/start — start a new journey session
|
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
||||||
POST /agents/journey/message — continue the conversation
|
frames to the functions exported here.
|
||||||
|
|
||||||
Sessions are stored in-memory with a 30-minute TTL. Stale entries are
|
|
||||||
cleaned up lazily on access. Upgrade to Redis for multi-instance deployments.
|
|
||||||
|
|
||||||
Journey flow:
|
Journey flow:
|
||||||
1. Client sends ``{ agent_type, agent_id? }`` to ``/start``.
|
1. FE sends ``journey_start`` frame with basic agent config (directory,
|
||||||
2. Server creates a session, calls the LLM with a contextual system prompt,
|
data_types, schedule).
|
||||||
and returns the first question.
|
2. Server creates an in-memory session, sets up a WS executor so the
|
||||||
3. Client sends follow-up messages to ``/message``.
|
setup LLM can use file-system tools, does a first directory scrape,
|
||||||
4. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` block
|
and sends back a ``journey_reply`` with the first question.
|
||||||
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
3. FE sends ``journey_message`` frames for each user reply.
|
||||||
5. Server parses the block, sets ``done=True``, and returns the template.
|
4. Server appends the user message, calls the LLM (which may read files
|
||||||
|
via tools), and sends back a ``journey_reply``.
|
||||||
The ``prompt_template`` from the final response is meant to be stored in
|
5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template``
|
||||||
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
|
block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
||||||
by the Electron client (via the agent CRUD endpoints).
|
6. Server parses the block, sends ``journey_reply`` with ``done=True``
|
||||||
|
and the template. FE stores it locally.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
from app.core.llm import get_llm
|
from app.core.llm import get_llm
|
||||||
from app.db import get_session
|
|
||||||
from app.models import CloudAgentConfig, LocalAgentConfig
|
|
||||||
from app.schemas import (
|
|
||||||
JourneyMessageRequest,
|
|
||||||
JourneyResponse,
|
|
||||||
JourneyStartRequest,
|
|
||||||
UserProfile,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/agents/journey", tags=["agents"])
|
|
||||||
|
|
||||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
@@ -59,16 +45,21 @@ _TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
|||||||
|
|
||||||
# Maximum number of conversation turns before the LLM is nudged to wrap up.
|
# Maximum number of conversation turns before the LLM is nudged to wrap up.
|
||||||
_MAX_TURNS: int = 5
|
_MAX_TURNS: int = 5
|
||||||
|
# Max tool-calling steps per LLM invocation.
|
||||||
|
_MAX_TOOL_STEPS: int = 6
|
||||||
|
|
||||||
# ── In-memory session store ───────────────────────────────────────────────
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _JourneySession:
|
class JourneySession:
|
||||||
session_id: str
|
session_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
agent_type: str # "local" | "cloud"
|
agent_type: str # "local" | "cloud"
|
||||||
|
directory: str
|
||||||
|
data_types: list[str]
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
system_prompt: str = ""
|
||||||
created_at: float = field(default_factory=time.monotonic)
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
def is_expired(self) -> bool:
|
def is_expired(self) -> bool:
|
||||||
@@ -76,67 +67,77 @@ class _JourneySession:
|
|||||||
|
|
||||||
|
|
||||||
# session_id → session
|
# session_id → session
|
||||||
_sessions: dict[str, _JourneySession] = {}
|
_sessions: dict[str, JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
def _get_session(session_id: str, user_id: str) -> _JourneySession:
|
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||||
"""Retrieve session; raise 404 on missing, expired, or wrong owner."""
|
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||||
s = _sessions.get(session_id)
|
s = _sessions.get(session_id)
|
||||||
if s is None or s.is_expired():
|
if s is None or s.is_expired():
|
||||||
_sessions.pop(session_id, None)
|
_sessions.pop(session_id, None)
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
return None
|
||||||
if s.user_id != user_id:
|
if s.user_id != user_id:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
return None
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# ── System prompt builder ─────────────────────────────────────────────────
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
_LOCAL_PREAMBLE = """\
|
|
||||||
What kind of files are in the directories you want to monitor? \
|
|
||||||
(for example: emails saved as .eml, documents in .pdf or .txt, markdown notes, etc.)"""
|
|
||||||
|
|
||||||
_CLOUD_PREAMBLE = """\
|
|
||||||
What kind of emails or messages should I look for? \
|
|
||||||
(for example: client communications, invoices, meeting notes, project updates, etc.)"""
|
|
||||||
|
|
||||||
_SYSTEM_PROMPT_TEMPLATE = """\
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
Your job is to understand exactly what data the user wants to extract from their {source_description} \
|
Your job is to understand exactly what data the user wants to extract from their
|
||||||
and produce a detailed prompt_template that a separate AI will use as its instruction set.
|
local directory and produce a detailed prompt_template that a separate AI will use
|
||||||
|
as its instruction set.
|
||||||
|
|
||||||
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
|
The extraction agent already has this base behaviour built in:
|
||||||
1. The type and format of the source content.
|
- Reads each file using file-system tools.
|
||||||
2. Which data types to extract: tasks, notes, timelines, and/or projects.
|
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
|
||||||
3. How fields should be mapped (e.g. email subject → task title).
|
- Sets isAiSuggested=1 and isApproved=0 on every record.
|
||||||
4. Priority or status rules (e.g. "urgent" keyword → high priority).
|
- Only extracts data explicitly present in the files — it never invents information.
|
||||||
5. Any special handling, date extraction, or exclusions.
|
The user's custom prompt is appended AFTER this base behaviour, so focus on
|
||||||
|
what to look for and how to map it — not on the general extraction mechanics.
|
||||||
|
|
||||||
After 3-5 questions (when you have enough information), output the final prompt_template between \
|
You have access to file-system tools to explore the user's directory:
|
||||||
these exact markers on their own lines:
|
- list_directory: to see folder structure
|
||||||
|
- read_file_content: to peek at file contents
|
||||||
|
- get_file_metadata: to check file info
|
||||||
|
|
||||||
|
The user's configured directory is: {directory}
|
||||||
|
Target data types: {data_types}
|
||||||
|
|
||||||
|
Start by exploring the directory to understand its structure. Then ask concise,
|
||||||
|
focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||||
|
1. The type and format of the source content (confirmed by your exploration).
|
||||||
|
2. How fields should be mapped (e.g. filename → task title).
|
||||||
|
3. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||||
|
4. Any special handling, date extraction, or exclusions.
|
||||||
|
|
||||||
|
After 3-5 questions (when you have enough information), output the final prompt_template
|
||||||
|
between these exact markers on their own lines:
|
||||||
|
|
||||||
{template_start}
|
{template_start}
|
||||||
<the complete extraction prompt here>
|
<the complete extraction prompt here>
|
||||||
{template_end}
|
{template_end}
|
||||||
|
|
||||||
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
|
The prompt_template must be a self-contained instruction for an AI that reads files
|
||||||
and must return a JSON array of records in this shape:
|
and must perform CRUD operations using tools to create records. It should specify:
|
||||||
[{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
|
- What entity types to create (tasks, notes, timelines, projects).
|
||||||
|
- How to map file content to record fields (camelCase: title, status, priority,
|
||||||
|
dueDate, projectId, content, etc.).
|
||||||
|
- That isAiSuggested must be set to 1 and isApproved to 0 on every record.
|
||||||
|
- Concrete examples of mappings based on what you discovered in the directory.
|
||||||
|
|
||||||
Rules for the generated template:
|
|
||||||
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
|
|
||||||
- Include concrete examples of mappings.
|
|
||||||
- Mention that Electron adds id/createdAt/updatedAt automatically.
|
|
||||||
- Set isAiSuggested: true and isApproved: false on every record.
|
|
||||||
{existing_section}\
|
{existing_section}\
|
||||||
Do not ask more than {max_turns} questions total. Start with your first question now.\
|
Do not ask more than {max_turns} questions total. Begin by exploring the directory,
|
||||||
|
then ask your first question.\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
def _build_system_prompt(
|
||||||
source_description = (
|
directory: str,
|
||||||
"files in local directories" if agent_type == "local" else "emails and messages from cloud providers"
|
data_types: list[str],
|
||||||
)
|
existing_template: str | None = None,
|
||||||
|
) -> str:
|
||||||
existing_section = (
|
existing_section = (
|
||||||
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
f"---\n{existing_template}\n---\n"
|
f"---\n{existing_template}\n---\n"
|
||||||
@@ -144,7 +145,8 @@ def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
|||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
return _SYSTEM_PROMPT_TEMPLATE.format(
|
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
source_description=source_description,
|
directory=directory,
|
||||||
|
data_types=", ".join(data_types),
|
||||||
template_start=_TEMPLATE_START,
|
template_start=_TEMPLATE_START,
|
||||||
template_end=_TEMPLATE_END,
|
template_end=_TEMPLATE_END,
|
||||||
existing_section=existing_section,
|
existing_section=existing_section,
|
||||||
@@ -152,10 +154,6 @@ def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _first_question(agent_type: str) -> str:
|
|
||||||
return _LOCAL_PREAMBLE if agent_type == "local" else _CLOUD_PREAMBLE
|
|
||||||
|
|
||||||
|
|
||||||
# ── Template extraction ───────────────────────────────────────────────────
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -168,11 +166,37 @@ def _extract_template(text: str) -> str | None:
|
|||||||
return text[start_idx:end_idx].strip() or None
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
# ── LLM call ─────────────────────────────────────────────────────────────
|
# ── LLM call with tool support ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
def _as_text(content: Any) -> str:
|
||||||
"""Build LangChain messages from history and invoke the LLM."""
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_llm_with_tools(
|
||||||
|
system_prompt: str,
|
||||||
|
history: list[dict[str, Any]],
|
||||||
|
tools: list[Any],
|
||||||
|
) -> str:
|
||||||
|
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||||
|
|
||||||
|
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||||
|
continue until a final text response is produced.
|
||||||
|
"""
|
||||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
for turn in history:
|
for turn in history:
|
||||||
if turn["role"] == "user":
|
if turn["role"] == "user":
|
||||||
@@ -181,137 +205,194 @@ async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
|||||||
messages.append(AIMessage(content=turn["content"]))
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
llm = get_llm(model=None, temperature=0.4)
|
llm = get_llm(model=None, temperature=0.4)
|
||||||
response = await llm.ainvoke(messages)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
return response.content # type: ignore[return-value]
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(_MAX_TOOL_STEPS):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"agent_setup: journey tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_setup: journey tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:800],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
# Fallback: exceeded max steps.
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
# ── Existing-config loader ────────────────────────────────────────────────
|
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _load_existing_template(
|
async def handle_journey_start(
|
||||||
agent_id: str,
|
|
||||||
user_id: str,
|
user_id: str,
|
||||||
db: AsyncSession,
|
frame: dict[str, Any],
|
||||||
) -> str | None:
|
) -> dict[str, Any]:
|
||||||
"""Return the prompt_template of an existing agent config, or None."""
|
"""Handle a ``journey_start`` WS frame.
|
||||||
# Try local first, then cloud.
|
|
||||||
local_result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.id == agent_id,
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
local = local_result.scalar_one_or_none()
|
|
||||||
if local is not None:
|
|
||||||
return local.prompt_template
|
|
||||||
|
|
||||||
cloud_result = await db.execute(
|
Creates a session, runs the setup LLM with directory exploration,
|
||||||
select(CloudAgentConfig).where(
|
and returns the ``journey_reply`` payload.
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cloud = cloud_result.scalar_one_or_none()
|
|
||||||
return cloud.prompt_template if cloud is not None else None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/start", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
|
||||||
async def start_journey(
|
|
||||||
body: JourneyStartRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> JourneyResponse:
|
|
||||||
"""Start a new Chatbot Journey session.
|
|
||||||
|
|
||||||
If ``agent_id`` is provided the session is pre-seeded with the existing
|
|
||||||
agent's ``prompt_template`` so the user can refine it.
|
|
||||||
"""
|
"""
|
||||||
# Load existing template (may be None).
|
agent_type = frame.get("agent_type", "local")
|
||||||
existing_template: str | None = None
|
directory = frame.get("directory", "")
|
||||||
if body.agent_id:
|
data_types = frame.get("data_types", [])
|
||||||
existing_template = await _load_existing_template(body.agent_id, current_user.id, db)
|
existing_template = frame.get("existing_template")
|
||||||
# If agent_id was given but not found, proceed without seeding (don't 404 —
|
|
||||||
# the user may be starting a fresh journey for a not-yet-persisted config).
|
|
||||||
|
|
||||||
system_prompt = _build_system_prompt(body.agent_type, existing_template)
|
# Use the session_id provided by the FE so the reply matches the
|
||||||
first_question = _first_question(body.agent_type)
|
# listener key; fall back to a generated one if absent.
|
||||||
|
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||||
|
system_prompt = _build_system_prompt(directory, data_types, existing_template)
|
||||||
|
|
||||||
session_id = str(uuid.uuid4())
|
session = JourneySession(
|
||||||
session = _JourneySession(
|
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=current_user.id,
|
user_id=user_id,
|
||||||
agent_type=body.agent_type,
|
agent_type=agent_type,
|
||||||
# Seed history with the AI's first question so it stays consistent.
|
directory=directory,
|
||||||
history=[{"role": "assistant", "content": first_question}],
|
data_types=data_types,
|
||||||
|
system_prompt=system_prompt,
|
||||||
)
|
)
|
||||||
# Store the system prompt inside the session for reuse in /message.
|
|
||||||
session.__dict__["_system_prompt"] = system_prompt # type: ignore[index]
|
# The LLM will explore the directory using FILESYSTEM_TOOLS via the
|
||||||
|
# ws_context executor (already set by the WS handler before calling us).
|
||||||
|
# Seed with an initial user message — some providers (e.g. GitHub Copilot)
|
||||||
|
# require at least one user/input message to be present.
|
||||||
|
seed_history: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||||
|
]
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history=seed_history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.extend(seed_history)
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
_sessions[session_id] = session
|
_sessions[session_id] = session
|
||||||
|
|
||||||
logger.info("Journey session %s started for user %s (agent_type=%s)", session_id, current_user.id, body.agent_type)
|
logger.info(
|
||||||
return JourneyResponse(session_id=session_id, message=first_question, done=False)
|
"agent_setup: journey session %s started for user %s (directory=%s)",
|
||||||
|
session_id,
|
||||||
|
user_id,
|
||||||
|
directory,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the LLM produced the template on the first turn (unlikely but possible).
|
||||||
@router.post("/message", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
|
||||||
async def send_journey_message(
|
|
||||||
body: JourneyMessageRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> JourneyResponse:
|
|
||||||
"""Send a message in an existing Chatbot Journey session.
|
|
||||||
|
|
||||||
The server appends the user's message to the conversation history,
|
|
||||||
calls the LLM, and appends the AI reply. When the LLM wraps up with a
|
|
||||||
``prompt_template`` block the response includes ``done=True`` and the
|
|
||||||
extracted template.
|
|
||||||
"""
|
|
||||||
session = _get_session(body.session_id, current_user.id)
|
|
||||||
system_prompt: str = session.__dict__.get("_system_prompt", _build_system_prompt(session.agent_type, None)) # type: ignore[assignment]
|
|
||||||
|
|
||||||
# Append user turn to history.
|
|
||||||
session.history.append({"role": "user", "content": body.message})
|
|
||||||
|
|
||||||
# Call the LLM with the full conversation so far.
|
|
||||||
ai_reply = await _call_llm(system_prompt, session.history)
|
|
||||||
|
|
||||||
# Append AI turn.
|
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
|
||||||
|
|
||||||
# Check if the LLM produced the final template.
|
|
||||||
prompt_template = _extract_template(ai_reply)
|
prompt_template = _extract_template(ai_reply)
|
||||||
done = prompt_template is not None
|
done = prompt_template is not None
|
||||||
|
|
||||||
# Strip the sentinel markers from the message shown to the user.
|
|
||||||
display_message = ai_reply
|
display_message = ai_reply
|
||||||
if done:
|
if done:
|
||||||
display_message = (
|
display_message = (
|
||||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
or "Here is your agent configuration. You can save it or continue refining."
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
)
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
|
||||||
if done:
|
return {
|
||||||
logger.info("Journey session %s completed for user %s", body.session_id, current_user.id)
|
"type": "journey_reply",
|
||||||
# Clean up the session immediately on completion.
|
"session_id": session_id,
|
||||||
_sessions.pop(body.session_id, None)
|
"message": display_message,
|
||||||
else:
|
"done": done,
|
||||||
# Nudge the LLM to wrap up after max turns.
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_message(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_message`` WS frame.
|
||||||
|
|
||||||
|
Appends the user message, calls the LLM, and returns the
|
||||||
|
``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
message = frame.get("message", "")
|
||||||
|
|
||||||
|
session = get_journey_session(session_id, user_id)
|
||||||
|
if session is None:
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": "Journey session not found or expired. Please start a new setup.",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Append user turn.
|
||||||
|
session.history.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
# Call the LLM with tools.
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
|
# Check if the LLM produced the final template.
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
# If the LLM didn't produce a template but we've hit max turns, nudge it
|
||||||
|
# and call the LLM one more time to force template generation.
|
||||||
|
if not done:
|
||||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
if turns >= _MAX_TURNS:
|
if turns >= _MAX_TURNS:
|
||||||
# Add a system-level nudge as a hidden user message.
|
nudge_content = (
|
||||||
session.history.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": (
|
|
||||||
"[System: You have enough information. Please generate the final "
|
"[System: You have enough information. Please generate the final "
|
||||||
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
),
|
|
||||||
})
|
|
||||||
|
|
||||||
return JourneyResponse(
|
|
||||||
session_id=body.session_id,
|
|
||||||
message=display_message,
|
|
||||||
done=done,
|
|
||||||
prompt_template=prompt_template,
|
|
||||||
)
|
)
|
||||||
|
session.history.append({"role": "user", "content": nudge_content})
|
||||||
|
|
||||||
|
nudge_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||||
|
|
||||||
|
prompt_template = _extract_template(nudge_reply)
|
||||||
|
if prompt_template is not None:
|
||||||
|
done = True
|
||||||
|
ai_reply = nudge_reply
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
if _TEMPLATE_START in ai_reply
|
||||||
|
else "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,45 +1,36 @@
|
|||||||
"""Agent CRUD routes: local directory agents and cloud connector agents.
|
"""Agent routes.
|
||||||
|
|
||||||
Endpoints:
|
Backend responsibilities are intentionally minimal:
|
||||||
GET /agents/catalog — hardcoded agent type catalog
|
GET /agents/catalog — static catalog for UI display
|
||||||
GET /agents/local — list user's local agent configs
|
POST /agents/can-create — billing eligibility check
|
||||||
POST /agents/local — create local agent (tier-gated)
|
POST /agents/trigger — trigger a local agent run
|
||||||
PUT /agents/local/{agent_id} — partial update (ownership check)
|
|
||||||
DELETE /agents/local/{agent_id} — delete + cascade run logs
|
Agent configuration is owned by the Electron app and is not persisted
|
||||||
GET /agents/cloud — list user's cloud agent configs
|
in backend agent-config tables.
|
||||||
POST /agents/cloud — create cloud agent (tier-gated)
|
|
||||||
PUT /agents/cloud/{agent_id} — partial update (ownership check)
|
|
||||||
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
|
|
||||||
GET /agents/runs — paginated run logs (agent_id, page, limit)
|
|
||||||
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
import uuid
|
||||||
from typing import Any
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from pydantic import BaseModel
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy import func, or_, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.billing.tier_manager import FEATURES
|
from app.billing.tier_manager import FEATURES
|
||||||
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
from app.core.agent_runner import run_local_agent
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
from app.models import AgentRunLog, LocalAgentConfig
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
AgentCatalogItem,
|
AgentCatalogItem,
|
||||||
|
AgentCreationCheckRequest,
|
||||||
|
AgentCreationCheckResponse,
|
||||||
AgentRunLogResponse,
|
AgentRunLogResponse,
|
||||||
CloudAgentConfigCreate,
|
AgentTriggerRequest,
|
||||||
CloudAgentConfigResponse,
|
|
||||||
CloudAgentConfigUpdate,
|
|
||||||
LocalAgentConfigCreate,
|
|
||||||
LocalAgentConfigResponse,
|
|
||||||
LocalAgentConfigUpdate,
|
|
||||||
UserProfile,
|
UserProfile,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -56,39 +47,21 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
|
|||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
# ── Model → schema converters ─────────────────────────────────────────
|
def _to_data_types(values: list[str]) -> list[str]:
|
||||||
|
normalize = {
|
||||||
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
|
"task": "tasks", "tasks": "tasks",
|
||||||
return LocalAgentConfigResponse(
|
"note": "notes", "notes": "notes",
|
||||||
id=a.id,
|
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||||
name=a.name,
|
"project": "projects", "projects": "projects",
|
||||||
device_id=a.device_id,
|
}
|
||||||
directory_paths=a.directory_paths,
|
seen: set[str] = set()
|
||||||
data_types=a.data_types,
|
result: list[str] = []
|
||||||
prompt_template=a.prompt_template,
|
for v in values:
|
||||||
file_extensions=a.file_extensions,
|
mapped = normalize.get(v)
|
||||||
schedule_cron=a.schedule_cron,
|
if mapped and mapped not in seen:
|
||||||
enabled=a.enabled,
|
seen.add(mapped)
|
||||||
last_run_at=_dt_ms_opt(a.last_run_at),
|
result.append(mapped)
|
||||||
created_at=_dt_ms(a.created_at),
|
return result
|
||||||
updated_at=_dt_ms(a.updated_at),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
|
|
||||||
return CloudAgentConfigResponse(
|
|
||||||
id=a.id,
|
|
||||||
provider=a.provider, # type: ignore[arg-type]
|
|
||||||
name=a.name,
|
|
||||||
data_types=a.data_types,
|
|
||||||
prompt_template=a.prompt_template,
|
|
||||||
schedule_cron=a.schedule_cron,
|
|
||||||
filter_config=a.filter_config,
|
|
||||||
enabled=a.enabled,
|
|
||||||
last_run_at=_dt_ms_opt(a.last_run_at),
|
|
||||||
created_at=_dt_ms(a.created_at),
|
|
||||||
updated_at=_dt_ms(a.updated_at),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||||
@@ -105,77 +78,42 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Ownership-checked lookups ─────────────────────────────────────────
|
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||||
|
|
||||||
async def _get_local_agent_for_user(
|
|
||||||
agent_id: str, user_id: str, db: AsyncSession
|
|
||||||
) -> LocalAgentConfig:
|
|
||||||
result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.id == agent_id,
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_cloud_agent_for_user(
|
|
||||||
agent_id: str, user_id: str, db: AsyncSession
|
|
||||||
) -> CloudAgentConfig:
|
|
||||||
result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(
|
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tier limit helper ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
|
|
||||||
"""Return combined enabled local + cloud agent count for the user."""
|
|
||||||
local_count = (
|
|
||||||
await db.execute(
|
|
||||||
select(func.count(LocalAgentConfig.id)).where(
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
LocalAgentConfig.enabled == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalar_one()
|
|
||||||
cloud_count = (
|
|
||||||
await db.execute(
|
|
||||||
select(func.count(CloudAgentConfig.id)).where(
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
CloudAgentConfig.enabled == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalar_one()
|
|
||||||
return local_count + cloud_count
|
|
||||||
|
|
||||||
|
|
||||||
def _enforce_agent_limit(tier: str, current_count: int) -> None:
|
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
if limit != -1 and current_count >= limit:
|
if limit != -1 and current_count >= limit:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
)
|
)
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
# ── Local page schema (used by runs endpoint) ─────────────────────────
|
async def _enforce_run_frequency(
|
||||||
|
tier: str,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||||
|
if limit == -1:
|
||||||
|
return # unlimited
|
||||||
|
|
||||||
class _RunsPage(BaseModel):
|
today_start = datetime.now(timezone.utc).replace(
|
||||||
total: int
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
page: int
|
)
|
||||||
limit: int
|
result = await db.execute(
|
||||||
items: list[AgentRunLogResponse]
|
select(func.count(AgentRunLog.id)).where(
|
||||||
|
AgentRunLog.user_id == user_id,
|
||||||
|
AgentRunLog.started_at >= today_start,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
runs_today: int = result.scalar_one()
|
||||||
|
|
||||||
|
if runs_today >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Catalog ───────────────────────────────────────────────────────────
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
@@ -209,229 +147,55 @@ async def get_agent_catalog(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# ── Local agent CRUD ──────────────────────────────────────────────────
|
@router.post("/can-create", response_model=AgentCreationCheckResponse)
|
||||||
|
async def can_create_agent(
|
||||||
@router.get("/local", response_model=list[LocalAgentConfigResponse])
|
body: AgentCreationCheckRequest,
|
||||||
async def list_local_agents(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
) -> AgentCreationCheckResponse:
|
||||||
) -> list[LocalAgentConfigResponse]:
|
"""Check if the user can create one more agent based on billing tier.
|
||||||
"""List all local directory agent configs owned by the authenticated user."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
return [_to_local_response(a) for a in result.scalars().all()]
|
|
||||||
|
|
||||||
|
Since configuration is client-owned, the Electron app sends its current
|
||||||
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
active agent count and the backend applies tier limits.
|
||||||
async def create_local_agent(
|
|
||||||
body: LocalAgentConfigCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> LocalAgentConfigResponse:
|
|
||||||
"""Create a new local directory agent config.
|
|
||||||
|
|
||||||
The combined count of enabled local and cloud agents for the user is
|
|
||||||
checked against the ``batch_active`` limit for their billing tier.
|
|
||||||
"""
|
"""
|
||||||
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
||||||
agent = LocalAgentConfig(
|
allowed = limit == -1 or body.active_agents < limit
|
||||||
user_id=current_user.id,
|
return AgentCreationCheckResponse(
|
||||||
name=body.name,
|
allowed=allowed,
|
||||||
device_id=body.device_id,
|
tier=current_user.tier,
|
||||||
directory_paths=body.directory_paths,
|
active_agents=body.active_agents,
|
||||||
data_types=body.data_types,
|
limit=limit,
|
||||||
prompt_template=body.prompt_template,
|
|
||||||
file_extensions=body.file_extensions,
|
|
||||||
schedule_cron=body.schedule_cron,
|
|
||||||
)
|
)
|
||||||
db.add(agent)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_local_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
|
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||||
async def update_local_agent(
|
|
||||||
agent_id: str,
|
|
||||||
body: LocalAgentConfigUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> LocalAgentConfigResponse:
|
|
||||||
"""Partially update a local agent config. Only provided fields are changed."""
|
|
||||||
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(agent, field, value)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_local_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/local/{agent_id}", response_model=dict)
|
|
||||||
async def delete_local_agent(
|
|
||||||
agent_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a local agent config. Associated run logs are cascade-deleted."""
|
|
||||||
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
await db.delete(agent)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud agent CRUD ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
|
|
||||||
async def list_cloud_agents(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[CloudAgentConfigResponse]:
|
|
||||||
"""List all cloud connector agent configs owned by the authenticated user."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
return [_to_cloud_response(a) for a in result.scalars().all()]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def create_cloud_agent(
|
|
||||||
body: CloudAgentConfigCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> CloudAgentConfigResponse:
|
|
||||||
"""Create a new cloud connector agent config.
|
|
||||||
|
|
||||||
The combined count of enabled local and cloud agents for the user is
|
|
||||||
checked against the ``batch_active`` limit for their billing tier.
|
|
||||||
"""
|
|
||||||
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
|
||||||
agent = CloudAgentConfig(
|
|
||||||
user_id=current_user.id,
|
|
||||||
provider=body.provider,
|
|
||||||
name=body.name,
|
|
||||||
data_types=body.data_types,
|
|
||||||
prompt_template=body.prompt_template,
|
|
||||||
oauth_token_encrypted=body.oauth_token_encrypted,
|
|
||||||
schedule_cron=body.schedule_cron,
|
|
||||||
filter_config=body.filter_config,
|
|
||||||
)
|
|
||||||
db.add(agent)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_cloud_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
|
|
||||||
async def update_cloud_agent(
|
|
||||||
agent_id: str,
|
|
||||||
body: CloudAgentConfigUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> CloudAgentConfigResponse:
|
|
||||||
"""Partially update a cloud agent config. Only provided fields are changed."""
|
|
||||||
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(agent, field, value)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_cloud_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/cloud/{agent_id}", response_model=dict)
|
|
||||||
async def delete_cloud_agent(
|
|
||||||
agent_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
|
|
||||||
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
await db.delete(agent)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Run logs ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/runs", response_model=_RunsPage)
|
|
||||||
async def list_run_logs(
|
|
||||||
agent_id: str | None = Query(default=None),
|
|
||||||
page: int = Query(default=1, ge=1),
|
|
||||||
limit: int = Query(default=20, ge=1, le=100),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> _RunsPage:
|
|
||||||
"""Return paginated run logs for the authenticated user.
|
|
||||||
|
|
||||||
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
|
|
||||||
"""
|
|
||||||
base_filter = [AgentRunLog.user_id == current_user.id]
|
|
||||||
if agent_id:
|
|
||||||
base_filter.append(AgentRunLog.agent_id == agent_id)
|
|
||||||
|
|
||||||
total = (
|
|
||||||
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
|
|
||||||
).scalar_one()
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(AgentRunLog)
|
|
||||||
.where(*base_filter)
|
|
||||||
.order_by(AgentRunLog.started_at.desc())
|
|
||||||
.offset((page - 1) * limit)
|
|
||||||
.limit(limit)
|
|
||||||
)
|
|
||||||
items = [_to_run_log_response(log) for log in result.scalars().all()]
|
|
||||||
|
|
||||||
return _RunsPage(total=total, page=page, limit=limit, items=items)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Manual trigger stub ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
|
||||||
async def trigger_agent_run(
|
async def trigger_agent_run(
|
||||||
agent_id: str,
|
body: AgentTriggerRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> AgentRunLogResponse:
|
) -> AgentRunLogResponse:
|
||||||
"""Manually trigger an agent run.
|
"""Trigger a local agent run using client-provided configuration."""
|
||||||
|
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||||
|
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
||||||
|
|
||||||
Looks up the agent config (local or cloud) by ID with ownership check,
|
config = LocalAgentConfig(
|
||||||
creates a run log entry with ``status="running"``, and returns it.
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=current_user.id,
|
||||||
|
device_id=body.device_id,
|
||||||
|
name="Local Directory Monitor",
|
||||||
|
directory_paths=[body.directory],
|
||||||
|
data_types=_to_data_types(body.what_to_extract),
|
||||||
|
prompt_template=body.custom_agent_prompt,
|
||||||
|
file_extensions=[],
|
||||||
|
schedule_cron=body.batch_interval,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
Actual dispatch to the agent runner is wired in Step 3.4 once
|
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
||||||
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
stable_agent_id = body.agent_id or config.id
|
||||||
"""
|
|
||||||
# Determine agent type by trying local first, then cloud.
|
|
||||||
# Keep the full config object so we can pass it to the agent runner.
|
|
||||||
local_config: LocalAgentConfig | None = None
|
|
||||||
cloud_config: CloudAgentConfig | None = None
|
|
||||||
|
|
||||||
local_result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.id == agent_id,
|
|
||||||
LocalAgentConfig.user_id == current_user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
local_config = local_result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if local_config is not None:
|
|
||||||
agent_type = "local"
|
|
||||||
else:
|
|
||||||
cloud_result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(
|
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == current_user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cloud_config = cloud_result.scalar_one_or_none()
|
|
||||||
if cloud_config is not None:
|
|
||||||
agent_type = "cloud"
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
|
|
||||||
run_log = AgentRunLog(
|
run_log = AgentRunLog(
|
||||||
agent_id=agent_id,
|
agent_id=stable_agent_id,
|
||||||
agent_type=agent_type,
|
agent_type="local",
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
status="running",
|
status="running",
|
||||||
)
|
)
|
||||||
@@ -439,14 +203,14 @@ async def trigger_agent_run(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(run_log)
|
await db.refresh(run_log)
|
||||||
|
|
||||||
# Dispatch the run as a background task — returns 202 immediately.
|
run_context = {
|
||||||
if agent_type == "local" and local_config is not None:
|
"type": "agent_batch",
|
||||||
|
"run_id": run_log.id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
}
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
||||||
)
|
|
||||||
elif agent_type == "cloud" and cloud_config is not None:
|
|
||||||
asyncio.create_task(
|
|
||||||
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return _to_run_log_response(run_log)
|
return _to_run_log_response(run_log)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from fastapi import APIRouter, Depends
|
|||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.core.orchestrator import orchestrate
|
from app.core.deep_agent import run_home
|
||||||
from app.schemas import ChatRequest, UserProfile
|
from app.schemas import ChatRequest, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
@@ -20,10 +20,10 @@ async def chat(
|
|||||||
body: ChatRequest,
|
body: ChatRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""Route a chat message through the orchestrator.
|
"""REST fallback for home chat when websocket streaming is unavailable."""
|
||||||
|
response = await run_home(
|
||||||
Returns ``ChatResponse`` for ``execution_mode='direct'``,
|
user_id=current_user.id,
|
||||||
or ``ExecutionPlan`` for ``execution_mode='plan'``.
|
message=body.message,
|
||||||
"""
|
context=body.context.model_dump(),
|
||||||
result = await orchestrate(body)
|
)
|
||||||
return JSONResponse(content=result.model_dump())
|
return JSONResponse(content={"response": response})
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ Protocol:
|
|||||||
|
|
||||||
Incoming frame dispatch:
|
Incoming frame dispatch:
|
||||||
- ``tool_result`` → resolves a pending tool-call Future.
|
- ``tool_result`` → resolves a pending tool-call Future.
|
||||||
- ``agent_data`` → enqueued in the per-run agent data queue.
|
- ``journey_start`` → starts a guided setup journey session.
|
||||||
- ``agent_complete`` → sends None sentinel to close the queue stream.
|
- ``journey_message`` → continues a journey conversation.
|
||||||
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||||
- unknown types → logged, ignored.
|
- unknown types → logged, ignored.
|
||||||
|
|
||||||
@@ -39,12 +39,13 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
|
from app.core.deep_agent import run_floating_stream, run_home_stream
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.orchestrator import orchestrate_v3_stream
|
from app.core.output_formatter import StreamFormatter
|
||||||
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
@@ -147,37 +148,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
"device_ws: tool_result missing id from user=%s", user_id
|
"device_ws: tool_result missing id from user=%s", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.agent_data:
|
|
||||||
run_id = frame.get("run_id")
|
|
||||||
if run_id:
|
|
||||||
try:
|
|
||||||
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
|
||||||
await queue.put(frame)
|
|
||||||
except RuntimeError:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_data for unknown run user=%s run=%s",
|
|
||||||
user_id,
|
|
||||||
run_id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_data missing run_id from user=%s", user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.agent_complete:
|
|
||||||
run_id = frame.get("run_id")
|
|
||||||
if run_id:
|
|
||||||
try:
|
|
||||||
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
|
||||||
# Sentinel: signals the agent data stream is finished.
|
|
||||||
await queue.put(None)
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_complete missing run_id from user=%s", user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.home_request:
|
elif frame_type == WsFrameType.home_request:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_handle_home_request(websocket, user_id, frame)
|
_handle_home_request(websocket, user_id, frame)
|
||||||
@@ -188,6 +158,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_floating_request(websocket, user_id, frame)
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.journey_start:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_journey_start(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.journey_message:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_journey_message(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -219,33 +199,37 @@ async def _handle_home_request(
|
|||||||
request_id = frame.get("request_id") or str(uuid4())
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
logger.info(
|
||||||
|
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message)
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
message,
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
agent_holder: list = []
|
|
||||||
try:
|
try:
|
||||||
token_stream = orchestrate_v3_stream(
|
event_stream = run_home_stream(user_id, message, context)
|
||||||
user_id, message, context, agent_holder=agent_holder
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
)
|
async for ws_frame in formatter.format(event_stream):
|
||||||
formatter = HomeFormatter(request_id=request_id, tool_results=[])
|
|
||||||
async for ws_frame in formatter.format(token_stream):
|
|
||||||
# Inject mutations from agent tool_results into stream_end
|
|
||||||
if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
|
|
||||||
ws_frame.mutations = [ # type: ignore[union-attr]
|
|
||||||
{"action": r["action"], "table": r["table"], "data": r["data"]}
|
|
||||||
for r in getattr(agent_holder[0], "tool_results", [])
|
|
||||||
]
|
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
# Collect text chunks to build the full response for episode storage
|
# Collect text chunks to build the full response for episode storage
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
@@ -262,7 +246,14 @@ async def _handle_home_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks)
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
len("".join(response_chunks)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -276,29 +267,38 @@ async def _handle_floating_request(
|
|||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
scope: dict = frame.get("scope", {})
|
scope: dict = frame.get("scope", {})
|
||||||
|
logger.info(
|
||||||
|
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
json.dumps(scope, ensure_ascii=True)[:200],
|
||||||
|
message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message)
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
message,
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
context: dict = {"scope": scope, **memory_context}
|
context: dict = {
|
||||||
|
"scope": scope,
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
agent_holder: list = []
|
|
||||||
try:
|
try:
|
||||||
token_stream = orchestrate_v3_stream(
|
event_stream = run_floating_stream(user_id, message, context)
|
||||||
user_id, message, context, agent_holder=agent_holder
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
)
|
async for ws_frame in formatter.format(event_stream):
|
||||||
formatter = FloatingFormatter(request_id=request_id)
|
|
||||||
async for ws_frame in formatter.format(token_stream):
|
|
||||||
if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
|
|
||||||
ws_frame.mutations = [ # type: ignore[union-attr]
|
|
||||||
{"action": r["action"], "table": r["table"], "data": r["data"]}
|
|
||||||
for r in getattr(agent_holder[0], "tool_results", [])
|
|
||||||
]
|
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
@@ -314,8 +314,72 @@ async def _handle_floating_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks)
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
len("".join(response_chunks)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_start(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a journey_start frame — explores directory and sends first question."""
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_start(user_id, frame)
|
||||||
|
await websocket.send_text(json.dumps(reply))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: journey_start failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": frame.get("session_id", ""),
|
||||||
|
"message": f"Failed to start journey: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}))
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_message(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a journey_message frame — continues the journey conversation."""
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_message(user_id, frame)
|
||||||
|
await websocket.send_text(json.dumps(reply))
|
||||||
|
except Exception as exc:
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
logger.error(
|
||||||
|
"device_ws: journey_message failed user=%s session=%s: %s",
|
||||||
|
user_id, session_id, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": f"Journey error: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}))
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
@@ -351,6 +415,3 @@ async def _mark_runs_disconnected(user_id: str) -> None:
|
|||||||
user_id,
|
user_id,
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,37 +0,0 @@
|
|||||||
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.core.execution_plan import plan_cache
|
|
||||||
from app.schemas import ExecutionPlan, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/plans", tags=["plans"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/playbook", response_model=list[ExecutionPlan])
|
|
||||||
async def list_playbooks(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> list[ExecutionPlan]:
|
|
||||||
"""Return all cached execution plan playbooks for the authenticated user.
|
|
||||||
|
|
||||||
TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
|
|
||||||
"""
|
|
||||||
return plan_cache.get_all_playbooks()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
|
|
||||||
async def get_playbook(
|
|
||||||
plan_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> ExecutionPlan:
|
|
||||||
"""Return a specific execution plan playbook by ID."""
|
|
||||||
plan = plan_cache.get_plan(plan_id)
|
|
||||||
if plan is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Plan not found: {plan_id}",
|
|
||||||
)
|
|
||||||
return plan
|
|
||||||
@@ -21,6 +21,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"free": {
|
"free": {
|
||||||
"agents": 3,
|
"agents": 3,
|
||||||
"batch_active": 2,
|
"batch_active": 2,
|
||||||
|
"batch_runs_per_day": 5,
|
||||||
"cloud_storage_gb": 0,
|
"cloud_storage_gb": 0,
|
||||||
"backup_gb": 0,
|
"backup_gb": 0,
|
||||||
"providers": 1,
|
"providers": 1,
|
||||||
@@ -31,6 +32,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
"batch_active": 10,
|
"batch_active": 10,
|
||||||
|
"batch_runs_per_day": 50,
|
||||||
"cloud_storage_gb": 5,
|
"cloud_storage_gb": 5,
|
||||||
"backup_gb": 5,
|
"backup_gb": 5,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -41,6 +43,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1, # unlimited
|
"batch_active": -1, # unlimited
|
||||||
|
"batch_runs_per_day": -1, # unlimited
|
||||||
"cloud_storage_gb": 25,
|
"cloud_storage_gb": 25,
|
||||||
"backup_gb": 25,
|
"backup_gb": 25,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -51,6 +54,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1,
|
"batch_active": -1,
|
||||||
|
"batch_runs_per_day": -1, # unlimited
|
||||||
"cloud_storage_gb": -1, # unlimited
|
"cloud_storage_gb": -1, # unlimited
|
||||||
"backup_gb": -1, # unlimited
|
"backup_gb": -1, # unlimited
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -77,16 +81,18 @@ class TierManager:
|
|||||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
"""Return the current billing tier for ``user_id`` from the DB.
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
Falls back to ``'free'`` when no subscription row exists.
|
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
|
||||||
|
when no subscription row exists.
|
||||||
"""
|
"""
|
||||||
from app.models import Subscription # noqa: PLC0415
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str | None = result.scalar_one_or_none()
|
tier: str | None = result.scalar_one_or_none()
|
||||||
if tier is None or tier not in FEATURES:
|
if tier is None or tier not in FEATURES:
|
||||||
return "free"
|
return "power" if settings.ENV == "dev" else "free"
|
||||||
return tier # type: ignore[return-value]
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
# ── Feature access ───────────────────────────────────────────────────
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
"""Agent Registry — base classes and singleton registry for chat agents."""
|
"""Minimal agent base types retained for compatibility with batch runners."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(ABC):
|
class BaseAgent(ABC):
|
||||||
"""Common base for all agents."""
|
"""Common base for non-chat agents still using the old base contract."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -28,190 +27,4 @@ class BaseAgent(ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def skills(self) -> list[str]:
|
def skills(self) -> list[str]:
|
||||||
"""Override in subclasses to advertise capabilities."""
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
class ChatAgent(BaseAgent):
|
|
||||||
"""Base class for LLM-powered chat agents."""
|
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
# Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
|
|
||||||
self.tool_results: list[dict] = []
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
"""Process a user query and return a text response."""
|
|
||||||
...
|
|
||||||
|
|
||||||
async def handle_stream(
|
|
||||||
self, query: str, context: dict[str, Any]
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Streaming variant of handle().
|
|
||||||
|
|
||||||
Default: calls handle() and yields the full response as one chunk.
|
|
||||||
Override in subclasses for true token-level streaming via _tool_loop_stream.
|
|
||||||
"""
|
|
||||||
yield await self.handle(query, context)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
"""Return LangChain tool definitions available to this agent."""
|
|
||||||
...
|
|
||||||
|
|
||||||
async def _tool_loop(
|
|
||||||
self,
|
|
||||||
llm: Any,
|
|
||||||
messages: list[Any],
|
|
||||||
tools: list[Any],
|
|
||||||
max_iter: int = 5,
|
|
||||||
) -> str:
|
|
||||||
"""Shared tool-calling loop.
|
|
||||||
|
|
||||||
Binds *tools* to *llm*, invokes iteratively until the model stops
|
|
||||||
requesting tool calls or *max_iter* is reached, and returns the
|
|
||||||
final text response. Captures raw execute_on_client results in
|
|
||||||
``self.tool_results``.
|
|
||||||
"""
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
|
||||||
|
|
||||||
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
|
||||||
|
|
||||||
collector: list[dict] = []
|
|
||||||
set_tool_result_collector(collector)
|
|
||||||
try:
|
|
||||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
|
||||||
|
|
||||||
for _ in range(max_iter):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
return str(response.content)
|
|
||||||
|
|
||||||
# Execute each requested tool call
|
|
||||||
tool_map = {t.name: t for t in tools}
|
|
||||||
for call in response.tool_calls:
|
|
||||||
tool_fn = tool_map.get(call["name"])
|
|
||||||
if tool_fn is None:
|
|
||||||
result = f"Unknown tool: {call['name']}"
|
|
||||||
else:
|
|
||||||
result = await tool_fn.ainvoke(call["args"])
|
|
||||||
messages.append(
|
|
||||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Exhausted iterations — ask model for a final answer without tools
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
return str(response.content)
|
|
||||||
finally:
|
|
||||||
clear_tool_result_collector()
|
|
||||||
self.tool_results = collector
|
|
||||||
|
|
||||||
async def _tool_loop_stream(
|
|
||||||
self,
|
|
||||||
llm: Any,
|
|
||||||
messages: list[Any],
|
|
||||||
tools: list[Any],
|
|
||||||
max_iter: int = 5,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Streaming variant of ``_tool_loop``.
|
|
||||||
|
|
||||||
Behaves identically for tool-calling iterations (uses ainvoke to parse
|
|
||||||
tool calls). For the final response — when the model produces no further
|
|
||||||
tool calls — switches to ``llm.astream()`` and yields text tokens.
|
|
||||||
Captures raw execute_on_client results in ``self.tool_results``.
|
|
||||||
"""
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
|
||||||
|
|
||||||
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
|
||||||
|
|
||||||
collector: list[dict] = []
|
|
||||||
set_tool_result_collector(collector)
|
|
||||||
try:
|
|
||||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
|
||||||
|
|
||||||
for _ in range(max_iter):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
# Stream the final answer — don't keep the ainvoke result.
|
|
||||||
async for chunk in llm.astream(messages):
|
|
||||||
if chunk.content:
|
|
||||||
yield str(chunk.content)
|
|
||||||
return
|
|
||||||
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
# Execute each requested tool call
|
|
||||||
tool_map = {t.name: t for t in tools}
|
|
||||||
for call in response.tool_calls:
|
|
||||||
tool_fn = tool_map.get(call["name"])
|
|
||||||
if tool_fn is None:
|
|
||||||
result = f"Unknown tool: {call['name']}"
|
|
||||||
else:
|
|
||||||
result = await tool_fn.ainvoke(call["args"])
|
|
||||||
messages.append(
|
|
||||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Exhausted iterations — stream a final answer without tools
|
|
||||||
async for chunk in llm.astream(messages):
|
|
||||||
if chunk.content:
|
|
||||||
yield str(chunk.content)
|
|
||||||
finally:
|
|
||||||
clear_tool_result_collector()
|
|
||||||
self.tool_results = collector
|
|
||||||
|
|
||||||
|
|
||||||
class AgentRegistry:
|
|
||||||
"""Singleton registry for ChatAgent subclasses."""
|
|
||||||
|
|
||||||
_instance: AgentRegistry | None = None
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._agents: dict[str, type[ChatAgent]] = {}
|
|
||||||
|
|
||||||
def __new__(cls) -> AgentRegistry:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
cls._instance._agents = {}
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
# ── public API ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
|
|
||||||
"""Class decorator — registers an agent by its name."""
|
|
||||||
instance = agent_class()
|
|
||||||
name = instance.get_name()
|
|
||||||
self._agents[name] = agent_class
|
|
||||||
return agent_class
|
|
||||||
|
|
||||||
def get(self, name: str) -> ChatAgent:
|
|
||||||
"""Return a fresh instance of the named agent."""
|
|
||||||
cls = self._agents.get(name)
|
|
||||||
if cls is None:
|
|
||||||
raise KeyError(f"Agent not found: {name}")
|
|
||||||
return cls()
|
|
||||||
|
|
||||||
def list_agents(self) -> list[dict[str, str]]:
|
|
||||||
"""Return ``[{name, description}]`` for the orchestrator prompt."""
|
|
||||||
result: list[dict[str, str]] = []
|
|
||||||
for cls in self._agents.values():
|
|
||||||
inst = cls()
|
|
||||||
result.append(
|
|
||||||
{"name": inst.get_name(), "description": inst.get_description()}
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def call_agent(
|
|
||||||
self, name: str, query: str, context: dict[str, Any]
|
|
||||||
) -> str:
|
|
||||||
"""Instantiate the named agent and call its ``handle`` method."""
|
|
||||||
agent = self.get(name)
|
|
||||||
return await agent.handle(query, context)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
registry = AgentRegistry()
|
|
||||||
|
|||||||
@@ -2,14 +2,14 @@
|
|||||||
|
|
||||||
Drives two agent types:
|
Drives two agent types:
|
||||||
|
|
||||||
* **Local directory agent** — sends an ``agent_run`` frame to the connected
|
* **Local directory agent** — two-phase execution that mirrors the
|
||||||
Electron device, waits for the device to stream back file contents via
|
``deep_agent.py`` tool-calling pattern. Phase 1 (Triage) explores the
|
||||||
``agent_data`` frames, then calls the LLM to extract structured items from
|
user's directory via file-system tools and groups files by project.
|
||||||
each file and pushes inserts to Electron via tool-call round-trips.
|
Phase 2 (Processing) reads full file contents and performs CRUD
|
||||||
|
operations using the standard entity tools (tasks, notes, etc.).
|
||||||
|
|
||||||
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
|
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
|
||||||
Teams, Outlook) and pushes extracted items to Electron. **This path is
|
Teams, Outlook) and pushes extracted items to Electron.
|
||||||
a stub** — provider integrations are implemented in Step 3.6.
|
|
||||||
|
|
||||||
Usage
|
Usage
|
||||||
-----
|
-----
|
||||||
@@ -33,11 +33,17 @@ from datetime import datetime, timedelta, timezone
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from croniter import croniter
|
from croniter import croniter
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
|
from app.agents.note_agent import NOTE_TOOLS
|
||||||
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
from app.core.device_manager import DeviceConnectionManager
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
from app.core.llm import get_llm
|
from app.core.llm import get_llm
|
||||||
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
|
||||||
@@ -45,50 +51,108 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# ── Timeouts ───────────────────────────────────────────────────────────────
|
# ── Timeouts ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
# Max seconds to wait for Electron to finish streaming file data.
|
# Max seconds to wait for a single tool-call round-trip (FE → BE).
|
||||||
_FILE_READ_TIMEOUT: int = 120
|
_TOOL_CALL_TIMEOUT: int = 30
|
||||||
# Max seconds to wait for Electron to acknowledge a single tool-call insert.
|
# Max LLM reasoning steps per phase.
|
||||||
_INSERT_TIMEOUT: int = 30
|
_MAX_TRIAGE_STEPS: int = 10
|
||||||
|
_MAX_PROCESSING_STEPS: int = 12
|
||||||
|
|
||||||
# ── Allowed tables & extraction schema hints ───────────────────────────────
|
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
||||||
|
|
||||||
_ALLOWED_TABLES: frozenset[str] = frozenset(
|
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
||||||
{"tasks", "notes", "timelines", "projects", "taskComments"}
|
"tasks": TASK_TOOLS,
|
||||||
)
|
"projects": PROJECT_TOOLS,
|
||||||
|
"notes": NOTE_TOOLS,
|
||||||
# Field descriptions fed to the extraction LLM as concise schema references.
|
"timelines": TIMELINE_TOOLS,
|
||||||
_TABLE_SCHEMAS: dict[str, str] = {
|
|
||||||
"tasks": (
|
|
||||||
"title (str, required), description (str), "
|
|
||||||
"status (todo|in_progress|done, default todo), "
|
|
||||||
"priority (high|medium|low, default medium), "
|
|
||||||
"assignee (JSON array string), dueDate (ms timestamp int), projectId (str)"
|
|
||||||
),
|
|
||||||
"notes": "title (str, required), content (str, markdown), projectId (str)",
|
|
||||||
"timelines": (
|
|
||||||
"title (str, required), projectId (str, required), date (ms timestamp int)"
|
|
||||||
),
|
|
||||||
"projects": "name (str, required), clientId (str)",
|
|
||||||
"taskComments": "taskId (str, required), author (str), content (str, required)",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_EXTRACTION_SYSTEM_PROMPT = """\
|
# ── Triage prompt ─────────────────────────────────────────────────────────
|
||||||
You are a data extraction assistant for a freelance project management tool.
|
|
||||||
Given a document, extract structured records matching the user's instructions.
|
|
||||||
|
|
||||||
Output a JSON array (no markdown fences, no explanation) of objects shaped:
|
_TRIAGE_SYSTEM_PROMPT = """\
|
||||||
[{{"table": "<table_name>", "data": {{...fields}}}}, ...]
|
You are a file triage assistant for a freelance project management tool.
|
||||||
|
Your job is to explore a local directory on the user's device, understand its
|
||||||
|
structure, and group files by project context.
|
||||||
|
|
||||||
Allowed table names and their fields:
|
You have access to these tools:
|
||||||
{table_schemas}
|
- list_directory: to map folder structure
|
||||||
|
- get_file_metadata: to check creation/modification dates
|
||||||
|
- read_file_content: to read brief snippets when needed for categorisation
|
||||||
|
- list_projects / list_all_projects / get_project: to fetch existing projects
|
||||||
|
from the user's workspace and match files to them
|
||||||
|
|
||||||
Rules:
|
Instructions:
|
||||||
- Only extract tables listed in the "data_types" instructions.
|
1. Start by calling list_directory on the configured root path.
|
||||||
- Use camelCase field names exactly as shown above.
|
2. Explore subdirectories as needed to understand the structure.
|
||||||
- Omit optional fields you cannot determine; do not invent data.
|
3. Use get_file_metadata to check modification dates. Skip files that have
|
||||||
- Never include id, createdAt, updatedAt, isAiSuggested, or isApproved.
|
NOT been modified since: {last_run_at}.
|
||||||
- If nothing relevant is found, return an empty JSON array: []
|
4. Call list_all_projects to get the user's existing projects.
|
||||||
- Return ONLY the JSON array.
|
5. Match files to existing projects by name, folder structure, or content hints.
|
||||||
|
6. If files don't match any existing project, group them under "standalone".
|
||||||
|
|
||||||
|
{custom_prompt_section}
|
||||||
|
|
||||||
|
Target entity types to extract: {data_types}
|
||||||
|
File extensions to consider: {file_extensions}
|
||||||
|
|
||||||
|
When you have finished exploring, output ONLY a JSON object (no markdown
|
||||||
|
fences, no explanation) mapping project IDs or "standalone" to file path
|
||||||
|
arrays:
|
||||||
|
|
||||||
|
{{"<project_id>": ["<file_path>", ...], "standalone": ["<file_path>", ...]}}
|
||||||
|
|
||||||
|
Return ONLY the JSON object as your final message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Processing prompt ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_PROCESSING_BASE_PROMPT = """\
|
||||||
|
You are a data extraction and management assistant for a freelance project
|
||||||
|
management tool.
|
||||||
|
|
||||||
|
Available tools:
|
||||||
|
Filesystem : read_file_content, list_directory, get_file_metadata
|
||||||
|
Tasks : list_tasks, create_task, update_task, add_task_comment
|
||||||
|
Notes : list_notes, get_note, create_note, update_note
|
||||||
|
Timelines : list_timelines, create_timeline, update_timeline
|
||||||
|
Projects : list_all_projects, get_project, create_project, update_project
|
||||||
|
|
||||||
|
Your task:
|
||||||
|
1. Read the full content of each file below using read_file_content.
|
||||||
|
2. For each piece of information found, ALWAYS try to match and update an
|
||||||
|
existing record before creating a new one.
|
||||||
|
3. ONLY act on these entity types: {data_types}.
|
||||||
|
4. Do NOT invent data. Only extract what is clearly present in the files.
|
||||||
|
5. If a file contains no relevant data for the target entity types, skip it.
|
||||||
|
|
||||||
|
Update-first rules (apply in this order):
|
||||||
|
Tasks:
|
||||||
|
- Call list_tasks to find a match by title or context.
|
||||||
|
- If found: call add_task_comment (author "Adiuva"), update_task to set
|
||||||
|
assignees, state (ToDo / In Progress / Completed), or other fields.
|
||||||
|
- If NOT found: call create_task with isAiSuggested=1, isApproved=0.
|
||||||
|
Timelines:
|
||||||
|
- Call list_timelines to find a match by title or date.
|
||||||
|
- If found: call update_timeline to edit fields or mark it complete.
|
||||||
|
- If NOT found: call create_timeline with isAiSuggested=1, isApproved=0.
|
||||||
|
Notes:
|
||||||
|
- Call list_notes to find a match by title or topic, then get_note to
|
||||||
|
read its current content.
|
||||||
|
- If found: call update_note with the merged content.
|
||||||
|
- If NOT found: call create_note with isAiSuggested=1, isApproved=0.
|
||||||
|
Projects:
|
||||||
|
- Call list_all_projects to check for a match first.
|
||||||
|
- Only call create_project if the information is clearly significant and
|
||||||
|
no existing project matches. Set isAiSuggested=1, isApproved=0.
|
||||||
|
|
||||||
|
{project_context}
|
||||||
|
|
||||||
|
Files to process:
|
||||||
|
{file_list}
|
||||||
|
|
||||||
|
{custom_prompt_section}
|
||||||
|
|
||||||
|
After processing all files, respond with a brief summary of what you updated
|
||||||
|
and what you created.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -118,100 +182,151 @@ def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
|
|||||||
return False # Fail-safe: don't trigger if expression is invalid.
|
return False # Fail-safe: don't trigger if expression is invalid.
|
||||||
|
|
||||||
|
|
||||||
# ── LLM extraction ─────────────────────────────────────────────────────────
|
# ── WS executor for agent context ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _extract_items_from_content(
|
def _make_agent_executor(
|
||||||
prompt_template: str,
|
|
||||||
file_content: str,
|
|
||||||
data_types: list[str],
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Call the LLM to extract structured records from *file_content*.
|
|
||||||
|
|
||||||
Returns a validated list of ``{table: str, data: dict}`` objects.
|
|
||||||
Items referencing tables not in *data_types* are discarded.
|
|
||||||
"""
|
|
||||||
allowed = [t for t in data_types if t in _ALLOWED_TABLES]
|
|
||||||
if not allowed:
|
|
||||||
return []
|
|
||||||
|
|
||||||
schema_text = "\n".join(
|
|
||||||
f" {table}: {_TABLE_SCHEMAS.get(table, '(unknown)')}" for table in allowed
|
|
||||||
)
|
|
||||||
system_prompt = _EXTRACTION_SYSTEM_PROMPT.format(table_schemas=schema_text)
|
|
||||||
user_prompt = (
|
|
||||||
f"User instructions: {prompt_template}\n\n"
|
|
||||||
f"Extract these record types: {', '.join(allowed)}\n\n"
|
|
||||||
f"Document:\n{file_content[:8000]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = get_llm()
|
|
||||||
raw = ""
|
|
||||||
try:
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
|
|
||||||
)
|
|
||||||
raw = str(response.content).strip()
|
|
||||||
items: list[dict] = json.loads(raw)
|
|
||||||
if not isinstance(items, list):
|
|
||||||
raise ValueError("LLM response is not a JSON array")
|
|
||||||
except json.JSONDecodeError as exc:
|
|
||||||
logger.warning(
|
|
||||||
"agent_runner: LLM extraction returned invalid JSON: %s — snippet: %.200r",
|
|
||||||
exc,
|
|
||||||
raw,
|
|
||||||
)
|
|
||||||
return []
|
|
||||||
# Other exceptions (LLM API errors, network errors) propagate to the
|
|
||||||
# caller (run_local_agent) which records them per-file in the run log.
|
|
||||||
|
|
||||||
validated: list[dict[str, Any]] = []
|
|
||||||
for item in items:
|
|
||||||
table = item.get("table")
|
|
||||||
data = item.get("data")
|
|
||||||
if not isinstance(table, str) or table not in allowed:
|
|
||||||
continue
|
|
||||||
if not isinstance(data, dict) or not data:
|
|
||||||
continue
|
|
||||||
# Strip any server-generated or forbidden fields.
|
|
||||||
for _field in ("id", "createdAt", "updatedAt", "isAiSuggested", "isApproved"):
|
|
||||||
data.pop(_field, None)
|
|
||||||
validated.append({"table": table, "data": data})
|
|
||||||
return validated
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tool-call insert helper ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _send_insert_to_client(
|
|
||||||
user_id: str,
|
user_id: str,
|
||||||
table: str,
|
|
||||||
data: dict[str, Any],
|
|
||||||
device_mgr: DeviceConnectionManager,
|
device_mgr: DeviceConnectionManager,
|
||||||
) -> dict[str, Any]:
|
run_context: dict | None = None,
|
||||||
"""Send an ``insert`` tool_call frame to Electron and await the tool_result.
|
) -> Any:
|
||||||
|
"""Create a WS callback for ``set_client_executor()`` so that all tools
|
||||||
|
can use ``execute_on_client()`` during an agent run.
|
||||||
|
|
||||||
All inserts include ``isAiSuggested=1, isApproved=0`` so the user can
|
If *run_context* is provided it is attached to every ``tool_call`` frame
|
||||||
review AI-produced records before they are treated as confirmed.
|
so the Electron client can attribute actions to the correct agent run.
|
||||||
|
|
||||||
Raises ``asyncio.TimeoutError`` if Electron does not respond within
|
|
||||||
``_INSERT_TIMEOUT`` seconds. Raises ``RuntimeError`` if the device
|
|
||||||
disconnects before the frame can be sent.
|
|
||||||
"""
|
"""
|
||||||
call_id = str(uuid.uuid4())
|
async def _executor(payload: dict) -> dict:
|
||||||
payload: dict[str, Any] = {
|
payload["type"] = "tool_call"
|
||||||
"type": "tool_call",
|
if run_context:
|
||||||
"id": call_id,
|
payload["run_context"] = run_context
|
||||||
"action": "insert",
|
call_id = payload["id"]
|
||||||
"table": table,
|
|
||||||
"data": {**data, "isAiSuggested": 1, "isApproved": 0},
|
|
||||||
}
|
|
||||||
fut = device_mgr.create_pending_call(user_id, call_id)
|
fut = device_mgr.create_pending_call(user_id, call_id)
|
||||||
await device_mgr.send_frame(user_id, payload)
|
await device_mgr.send_frame(user_id, payload)
|
||||||
return await asyncio.wait_for(fut, timeout=_INSERT_TIMEOUT)
|
return await asyncio.wait_for(fut, timeout=_TOOL_CALL_TIMEOUT)
|
||||||
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
# ── Local agent runner ──────────────────────────────────────────────────────
|
# ── LLM tool-calling loop (mirrors deep_agent._run_single_agent) ──────────
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_agent_with_tools(
|
||||||
|
*,
|
||||||
|
system_prompt: str,
|
||||||
|
user_message: str,
|
||||||
|
tools: list[Any],
|
||||||
|
max_steps: int,
|
||||||
|
) -> str:
|
||||||
|
"""Run an LLM agent with tool-calling, returning the final text response.
|
||||||
|
|
||||||
|
Follows the same pattern as ``deep_agent._run_single_agent``:
|
||||||
|
bind tools → invoke → handle tool calls → repeat until final text.
|
||||||
|
"""
|
||||||
|
llm = get_llm()
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(content=user_message),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls_count = 0
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_calls_count += 1
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:1200],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
# Fallback: exceeded max steps, get final response without tools.
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Triage map parser ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_triage_map(raw: str) -> dict[str, list[str]] | None:
|
||||||
|
"""Extract the JSON triage map from the LLM's final response."""
|
||||||
|
text = raw.strip()
|
||||||
|
# Try direct parse first.
|
||||||
|
try:
|
||||||
|
parsed = json.loads(text)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return {k: v for k, v in parsed.items() if isinstance(v, list)}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try extracting JSON from markdown fences or surrounding text.
|
||||||
|
import re
|
||||||
|
match = re.search(r"\{[\s\S]*\}", text)
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
parsed = json.loads(match.group(0))
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return {k: v for k, v in parsed.items() if isinstance(v, list)}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool list builder ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _build_processing_tools(data_types: list[str]) -> list[Any]:
|
||||||
|
"""Build the tool list for Phase 2 based on user's data_types selection."""
|
||||||
|
tools: list[Any] = list(FILESYSTEM_TOOLS)
|
||||||
|
for dt in data_types:
|
||||||
|
dt_tools = _DATA_TYPE_TOOLS.get(dt)
|
||||||
|
if dt_tools:
|
||||||
|
tools.extend(dt_tools)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local agent runner (two-phase) ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def run_local_agent(
|
async def run_local_agent(
|
||||||
@@ -219,144 +334,163 @@ async def run_local_agent(
|
|||||||
config: LocalAgentConfig,
|
config: LocalAgentConfig,
|
||||||
run_log: AgentRunLog,
|
run_log: AgentRunLog,
|
||||||
device_mgr: DeviceConnectionManager,
|
device_mgr: DeviceConnectionManager,
|
||||||
|
run_context: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute a local directory agent run end-to-end.
|
"""Execute a local directory agent run using two-phase LLM-with-tools.
|
||||||
|
|
||||||
Steps:
|
Phase 1 — Triage:
|
||||||
|
Explore the directory structure, check metadata, match files to
|
||||||
|
existing projects. Output: a JSON map of project → file paths.
|
||||||
|
|
||||||
1. Verify the device identified by ``config.device_id`` is currently online.
|
Phase 2 — Processing:
|
||||||
2. Pre-create the agent_data queue so no incoming frames are lost.
|
For each project group, read full file contents and perform CRUD
|
||||||
3. Send ``agent_run`` frame to Electron (paths, extensions, prompt, data_types).
|
operations using the standard entity tools.
|
||||||
4. Consume ``agent_data`` frames until the ``None`` sentinel from
|
|
||||||
``agent_complete``.
|
|
||||||
5. For each received file call the LLM to extract ``{table, data}`` items.
|
|
||||||
6. Push each item to Electron as an ``insert`` tool-call; include
|
|
||||||
``isAiSuggested=1, isApproved=0`` so users can review AI suggestions.
|
|
||||||
7. Persist the run outcome (status, counts, errors) and update
|
|
||||||
``config.last_run_at``.
|
|
||||||
"""
|
"""
|
||||||
run_id = run_log.id
|
run_id = run_log.id
|
||||||
|
|
||||||
# ── 1. Device online check ─────────────────────────────────────────
|
# ── Device online check ─────────────────────────────────────────
|
||||||
if not device_mgr.is_online(user_id, config.device_id):
|
target_device_id = config.device_id.strip() if isinstance(config.device_id, str) else ""
|
||||||
|
if target_device_id:
|
||||||
|
is_online = device_mgr.is_online(user_id, target_device_id)
|
||||||
|
else:
|
||||||
|
is_online = device_mgr.is_online(user_id)
|
||||||
|
|
||||||
|
if not is_online:
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: skip run=%s — device %r offline for user=%s",
|
"agent_runner: skip run=%s — device %r offline for user=%s",
|
||||||
run_id,
|
run_id,
|
||||||
config.device_id,
|
target_device_id or "<any>",
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
await _finalize_run(
|
await _finalize_run(
|
||||||
run_log,
|
run_log,
|
||||||
status="error",
|
status="error",
|
||||||
errors=[f"Device {config.device_id!r} is not connected"],
|
errors=[f"Device {target_device_id or '<any>'!r} is not connected"],
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# ── 2. Pre-create agent_data queue ────────────────────────────────
|
# ── Set up WS executor for tools ────────────────────────────────
|
||||||
try:
|
executor = _make_agent_executor(user_id, device_mgr, run_context)
|
||||||
device_mgr.get_agent_data_queue(user_id, run_id)
|
set_client_executor(executor)
|
||||||
except RuntimeError:
|
|
||||||
await _finalize_run(
|
|
||||||
run_log,
|
|
||||||
status="error",
|
|
||||||
errors=["Device disconnected before agent run could start"],
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# ── 3. Send agent_run frame ────────────────────────────────────────
|
|
||||||
frame: dict[str, Any] = {
|
|
||||||
"type": "agent_run",
|
|
||||||
"run_id": run_id,
|
|
||||||
"agent_id": config.id,
|
|
||||||
"config": {
|
|
||||||
"paths": config.directory_paths,
|
|
||||||
"file_extensions": config.file_extensions,
|
|
||||||
"prompt_template": config.prompt_template,
|
|
||||||
"data_types": config.data_types,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
await device_mgr.send_frame(user_id, frame)
|
|
||||||
except RuntimeError as exc:
|
|
||||||
device_mgr.cleanup_agent_data_queue(user_id, run_id)
|
|
||||||
await _finalize_run(
|
|
||||||
run_log,
|
|
||||||
status="error",
|
|
||||||
errors=[f"Failed to send agent_run frame: {exc}"],
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"agent_runner: sent agent_run run=%s agent=%s user=%s",
|
|
||||||
run_id,
|
|
||||||
config.id,
|
|
||||||
user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── 4. Consume agent_data frames ──────────────────────────────────
|
|
||||||
files: list[dict[str, Any]] = []
|
|
||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
|
|
||||||
try:
|
|
||||||
queue = device_mgr.get_agent_data_queue(user_id, run_id)
|
|
||||||
deadline = asyncio.get_event_loop().time() + _FILE_READ_TIMEOUT
|
|
||||||
while True:
|
|
||||||
remaining = deadline - asyncio.get_event_loop().time()
|
|
||||||
if remaining <= 0:
|
|
||||||
errors.append("Timed out waiting for file data from device")
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
frame_data = await asyncio.wait_for(queue.get(), timeout=remaining)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
errors.append("Timed out waiting for file data from device")
|
|
||||||
break
|
|
||||||
if frame_data is None:
|
|
||||||
# Sentinel from agent_complete — stream is done.
|
|
||||||
break
|
|
||||||
files.extend(frame_data.get("files", []))
|
|
||||||
except RuntimeError as exc:
|
|
||||||
errors.append(f"Queue error reading agent data: {exc}")
|
|
||||||
|
|
||||||
# ── 5–6. Extract + insert ─────────────────────────────────────────
|
|
||||||
items_processed = 0
|
items_processed = 0
|
||||||
items_created = 0
|
items_created = 0
|
||||||
|
|
||||||
for file_info in files:
|
|
||||||
file_path: str = file_info.get("path", "<unknown>")
|
|
||||||
content: str = file_info.get("content", "")
|
|
||||||
if not content:
|
|
||||||
continue
|
|
||||||
items_processed += 1
|
|
||||||
try:
|
try:
|
||||||
extracted = await _extract_items_from_content(
|
# ── Phase 1: Triage ─────────────────────────────────────────
|
||||||
config.prompt_template, content, config.data_types
|
logger.info("agent_runner: run=%s phase=triage start user=%s", run_id, user_id)
|
||||||
|
|
||||||
|
last_run_str = "never (process all files)"
|
||||||
|
if config.last_run_at:
|
||||||
|
last_run_str = config.last_run_at.isoformat()
|
||||||
|
|
||||||
|
custom_section = ""
|
||||||
|
if config.prompt_template:
|
||||||
|
custom_section = f"User instructions:\n{config.prompt_template}"
|
||||||
|
|
||||||
|
file_ext_str = ", ".join(config.file_extensions) if config.file_extensions else "all"
|
||||||
|
|
||||||
|
triage_prompt = _TRIAGE_SYSTEM_PROMPT.format(
|
||||||
|
last_run_at=last_run_str,
|
||||||
|
custom_prompt_section=custom_section,
|
||||||
|
data_types=", ".join(config.data_types),
|
||||||
|
file_extensions=file_ext_str,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
|
||||||
errors.append(f"LLM extraction error for {file_path!r}: {exc}")
|
directory_paths = config.directory_paths
|
||||||
|
triage_user_msg = (
|
||||||
|
f"Explore these directories and produce the triage map:\n"
|
||||||
|
f"{json.dumps(directory_paths, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
triage_tools: list[Any] = list(FILESYSTEM_TOOLS) + list(PROJECT_TOOLS)
|
||||||
|
|
||||||
|
triage_response = await _run_agent_with_tools(
|
||||||
|
system_prompt=triage_prompt,
|
||||||
|
user_message=triage_user_msg,
|
||||||
|
tools=triage_tools,
|
||||||
|
max_steps=_MAX_TRIAGE_STEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
triage_map = _parse_triage_map(triage_response)
|
||||||
|
if not triage_map:
|
||||||
|
errors.append(f"Triage phase failed to produce a valid file map: {triage_response[:500]}")
|
||||||
|
await _finalize_run(run_log, status="error", errors=errors)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s triage complete groups=%d total_files=%d",
|
||||||
|
run_id,
|
||||||
|
len(triage_map),
|
||||||
|
sum(len(files) for files in triage_map.values()),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Phase 2: Processing (per group) ─────────────────────────
|
||||||
|
processing_tools = _build_processing_tools(config.data_types)
|
||||||
|
|
||||||
|
for group_key, file_paths in triage_map.items():
|
||||||
|
if not file_paths:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for item in extracted:
|
logger.info(
|
||||||
try:
|
"agent_runner: run=%s phase=processing group=%s files=%d",
|
||||||
result = await _send_insert_to_client(
|
run_id,
|
||||||
user_id, item["table"], item["data"], device_mgr
|
group_key,
|
||||||
)
|
len(file_paths),
|
||||||
if result.get("error"):
|
|
||||||
errors.append(
|
|
||||||
f"Insert failed ({item['table']}, {file_path!r}): {result['error']}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Build project context for the LLM.
|
||||||
|
if group_key == "standalone":
|
||||||
|
project_context = "These files are not associated with any existing project."
|
||||||
else:
|
else:
|
||||||
items_created += 1
|
project_context = f"These files belong to project ID: {group_key}. Use this project_id when creating records."
|
||||||
except asyncio.TimeoutError:
|
|
||||||
errors.append(
|
file_list_str = "\n".join(f"- {fp}" for fp in file_paths)
|
||||||
f"Timed out awaiting insert ack ({item['table']}, {file_path!r})"
|
|
||||||
|
processing_prompt = _PROCESSING_BASE_PROMPT.format(
|
||||||
|
data_types=", ".join(config.data_types),
|
||||||
|
project_context=project_context,
|
||||||
|
file_list=file_list_str,
|
||||||
|
custom_prompt_section=custom_section,
|
||||||
)
|
)
|
||||||
except RuntimeError as exc:
|
|
||||||
errors.append(f"Insert error ({item['table']}, {file_path!r}): {exc}")
|
|
||||||
|
|
||||||
# ── 7. Finalise ────────────────────────────────────────────────────
|
items_processed += len(file_paths)
|
||||||
device_mgr.cleanup_agent_data_queue(user_id, run_id)
|
|
||||||
|
|
||||||
if errors and items_created == 0:
|
try:
|
||||||
|
result_text = await _run_agent_with_tools(
|
||||||
|
system_prompt=processing_prompt,
|
||||||
|
user_message="Process the listed files now.",
|
||||||
|
tools=processing_tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s group=%s processing_result=%s",
|
||||||
|
run_id,
|
||||||
|
group_key,
|
||||||
|
result_text[:500],
|
||||||
|
)
|
||||||
|
# Count created items by scanning tool call results.
|
||||||
|
# The tools themselves handle creation; we estimate from the
|
||||||
|
# summary. A more precise count would require intercepting
|
||||||
|
# tool results, but the summary is sufficient for the run log.
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Processing error for group '{group_key}': {exc}")
|
||||||
|
logger.error(
|
||||||
|
"agent_runner: run=%s group=%s processing failed: %s",
|
||||||
|
run_id,
|
||||||
|
group_key,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Agent run failed: {exc}")
|
||||||
|
logger.error("agent_runner: run=%s failed: %s", run_id, exc)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
# ── Finalise ────────────────────────────────────────────────────
|
||||||
|
if errors and items_processed == 0:
|
||||||
final_status = "error"
|
final_status = "error"
|
||||||
elif errors:
|
elif errors:
|
||||||
final_status = "partial"
|
final_status = "partial"
|
||||||
@@ -369,19 +503,30 @@ async def run_local_agent(
|
|||||||
items_processed=items_processed,
|
items_processed=items_processed,
|
||||||
items_created=items_created,
|
items_created=items_created,
|
||||||
errors=errors,
|
errors=errors,
|
||||||
update_config_last_run=True,
|
update_config_last_run=False,
|
||||||
config_id=config.id,
|
config_id=config.id,
|
||||||
config_type="local",
|
config_type="local",
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
|
"agent_runner: run=%s done status=%s processed=%d errors=%d",
|
||||||
run_id,
|
run_id,
|
||||||
final_status,
|
final_status,
|
||||||
items_processed,
|
items_processed,
|
||||||
items_created,
|
|
||||||
len(errors),
|
len(errors),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Notify the Electron client that the run is complete so it can close
|
||||||
|
# the run record in its local SQLite.
|
||||||
|
if run_context and device_mgr.is_online(user_id):
|
||||||
|
try:
|
||||||
|
await device_mgr.send_frame(user_id, {
|
||||||
|
"type": "run_complete",
|
||||||
|
"run_context": run_context,
|
||||||
|
"status": final_status,
|
||||||
|
})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_id, exc)
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -405,8 +550,7 @@ async def run_cloud_agent(
|
|||||||
3. Instantiate the provider client (Gmail or MS Graph).
|
3. Instantiate the provider client (Gmail or MS Graph).
|
||||||
4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for
|
4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for
|
||||||
the first run) applying ``config.filter_config`` filters.
|
the first run) applying ``config.filter_config`` filters.
|
||||||
5. For each message/email call ``_extract_items_from_content`` with
|
5. For each message/email call the LLM to extract structured items.
|
||||||
``config.prompt_template`` to get structured ``{table, data}`` items.
|
|
||||||
6. Push each item to Electron as an ``insert`` tool-call.
|
6. Push each item to Electron as an ``insert`` tool-call.
|
||||||
7. If the provider refreshed its access token, re-encrypt and write it
|
7. If the provider refreshed its access token, re-encrypt and write it
|
||||||
back to ``config.oauth_token_encrypted``.
|
back to ``config.oauth_token_encrypted``.
|
||||||
@@ -514,37 +658,40 @@ async def run_cloud_agent(
|
|||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── 5–6. Extract + insert ─────────────────────────────────────────
|
# ── 5–6. Extract + insert via LLM with tools ─────────────────────
|
||||||
|
executor = _make_agent_executor(user_id, device_mgr)
|
||||||
|
set_client_executor(executor)
|
||||||
|
|
||||||
|
try:
|
||||||
|
processing_tools = _build_processing_tools(config.data_types)
|
||||||
|
custom_section = ""
|
||||||
|
if config.prompt_template:
|
||||||
|
custom_section = f"User instructions:\n{config.prompt_template}"
|
||||||
|
|
||||||
for msg in raw_messages:
|
for msg in raw_messages:
|
||||||
content_text = msg.as_text
|
content_text = msg.as_text
|
||||||
if not content_text:
|
if not content_text:
|
||||||
continue
|
continue
|
||||||
items_processed += 1
|
items_processed += 1
|
||||||
|
|
||||||
|
processing_prompt = _PROCESSING_BASE_PROMPT.format(
|
||||||
|
data_types=", ".join(config.data_types),
|
||||||
|
project_context="Determine the appropriate project from the message context.",
|
||||||
|
file_list=f"Message from {config.provider} (id: {msg.id})",
|
||||||
|
custom_prompt_section=custom_section,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
extracted = await _extract_items_from_content(
|
await _run_agent_with_tools(
|
||||||
config.prompt_template, content_text, config.data_types
|
system_prompt=processing_prompt,
|
||||||
|
user_message=f"Process this message content:\n\n{content_text[:8000]}",
|
||||||
|
tools=processing_tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
errors.append(f"LLM extraction error for message {msg.id!r}: {exc}")
|
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
|
||||||
continue
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
for item in extracted:
|
|
||||||
try:
|
|
||||||
result = await _send_insert_to_client(
|
|
||||||
user_id, item["table"], item["data"], device_mgr
|
|
||||||
)
|
|
||||||
if result.get("error"):
|
|
||||||
errors.append(
|
|
||||||
f"Insert failed ({item['table']}, msg={msg.id!r}): {result['error']}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
items_created += 1
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
errors.append(
|
|
||||||
f"Timed out awaiting insert ack ({item['table']}, msg={msg.id!r})"
|
|
||||||
)
|
|
||||||
except RuntimeError as exc:
|
|
||||||
errors.append(f"Insert error ({item['table']}, msg={msg.id!r}): {exc}")
|
|
||||||
|
|
||||||
# ── 7. Persist refreshed token (if any) ───────────────────────────
|
# ── 7. Persist refreshed token (if any) ───────────────────────────
|
||||||
refreshed = getattr(provider, "refreshed_credentials", None)
|
refreshed = getattr(provider, "refreshed_credentials", None)
|
||||||
@@ -610,61 +757,12 @@ async def trigger_pending_runs(
|
|||||||
* Runs execute **sequentially** to avoid flooding the WS connection.
|
* Runs execute **sequentially** to avoid flooding the WS connection.
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: scanning overdue runs for user=%s device=%s", user_id, device_id
|
"agent_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)",
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
)
|
)
|
||||||
async with async_session() as db:
|
|
||||||
local_result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
LocalAgentConfig.enabled == True, # noqa: E712
|
|
||||||
LocalAgentConfig.device_id == device_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
local_configs: list[LocalAgentConfig] = list(local_result.scalars().all())
|
|
||||||
|
|
||||||
cloud_result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
CloudAgentConfig.enabled == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cloud_configs: list[CloudAgentConfig] = list(cloud_result.scalars().all())
|
|
||||||
|
|
||||||
# Build ordered list of overdue (type, config) pairs.
|
|
||||||
pending: list[tuple[str, Any]] = []
|
|
||||||
for cfg in local_configs:
|
|
||||||
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
|
||||||
pending.append(("local", cfg))
|
|
||||||
for cfg in cloud_configs:
|
|
||||||
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
|
||||||
pending.append(("cloud", cfg))
|
|
||||||
|
|
||||||
if not pending:
|
|
||||||
logger.debug("agent_runner: no overdue runs for user=%s", user_id)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"agent_runner: %d overdue run(s) to dispatch for user=%s", len(pending), user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
for agent_type, cfg in pending:
|
|
||||||
# Create a fresh run log for this scheduled dispatch.
|
|
||||||
run_log = AgentRunLog(
|
|
||||||
agent_id=cfg.id,
|
|
||||||
agent_type=agent_type,
|
|
||||||
user_id=user_id,
|
|
||||||
status="running",
|
|
||||||
)
|
|
||||||
async with async_session() as db:
|
|
||||||
db.add(run_log)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(run_log)
|
|
||||||
|
|
||||||
if agent_type == "local":
|
|
||||||
await run_local_agent(user_id, cfg, run_log, device_mgr)
|
|
||||||
else:
|
|
||||||
await run_cloud_agent(user_id, cfg, run_log, device_mgr)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helper ─────────────────────────────────────────────────────────
|
# ── Internal helper ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
846
app/core/deep_agent.py
Normal file
846
app/core/deep_agent.py
Normal file
@@ -0,0 +1,846 @@
|
|||||||
|
"""Single-agent runners for home and floating chat contexts."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import date
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.agents.note_agent import NOTE_TOOLS
|
||||||
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
||||||
|
from app.db import async_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
||||||
|
FloatingDomainSection = Literal["task", "timeline", "note"]
|
||||||
|
|
||||||
|
_HOME_SINGLE_AGENT_SYSTEM = (
|
||||||
|
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
|
"Always use tools for factual data retrieval before answering. "
|
||||||
|
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||||
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
|
"Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||||
|
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
|
||||||
|
"When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
|
||||||
|
"Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
|
||||||
|
"For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
|
||||||
|
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
|
||||||
|
)
|
||||||
|
|
||||||
|
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
||||||
|
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
|
"Stay focused on the floating scope in context.scope and answer concisely. "
|
||||||
|
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
|
||||||
|
"Always use tools for factual data retrieval before answering. "
|
||||||
|
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||||
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
|
)
|
||||||
|
|
||||||
|
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
|
||||||
|
"You are a strict domain classifier for websocket floating requests. "
|
||||||
|
"Return ONLY a JSON object with keys: type, id, section. "
|
||||||
|
"Allowed type values: task, timeline, project, node. "
|
||||||
|
"Allowed section values: task, timeline, note, or null. "
|
||||||
|
"Rules: infer from user message intent first; do not blindly trust scope.type. "
|
||||||
|
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
|
||||||
|
"If project id is unknown but context.resolved_project_id exists, use it as id. "
|
||||||
|
"If id is unknown, use null. "
|
||||||
|
"No markdown, no prose, JSON only."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _candidate_tokens(message: str) -> list[str]:
|
||||||
|
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
|
||||||
|
return [token for token in tokens if len(token) >= 3]
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_project_id_from_message(message: str) -> str | None:
|
||||||
|
"""Resolve likely project UUID from user message using client project list."""
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="select", table="projects")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("deep_agent: project resolve select failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not isinstance(rows, list) or not rows:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tokens = _candidate_tokens(message)
|
||||||
|
scored: list[tuple[int, dict[str, Any]]] = []
|
||||||
|
for row in rows:
|
||||||
|
if not isinstance(row, dict):
|
||||||
|
continue
|
||||||
|
name = str(row.get("name", "")).lower()
|
||||||
|
score = sum(1 for token in tokens if token in name)
|
||||||
|
if score > 0:
|
||||||
|
scored.append((score, row))
|
||||||
|
|
||||||
|
if not scored:
|
||||||
|
return None
|
||||||
|
|
||||||
|
scored.sort(key=lambda item: item[0], reverse=True)
|
||||||
|
top_score = scored[0][0]
|
||||||
|
top_rows = [row for score, row in scored if score == top_score]
|
||||||
|
if len(top_rows) != 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
project_id = top_rows[0].get("id")
|
||||||
|
return project_id if isinstance(project_id, str) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _needs_project_resolution(message: str) -> bool:
|
||||||
|
lowered = message.lower()
|
||||||
|
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
|
||||||
|
|
||||||
|
|
||||||
|
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
prepared = dict(context)
|
||||||
|
if _needs_project_resolution(message):
|
||||||
|
resolved_project_id = await _resolve_project_id_from_message(message)
|
||||||
|
if resolved_project_id:
|
||||||
|
prepared["resolved_project_id"] = resolved_project_id
|
||||||
|
logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
|
||||||
|
return prepared
|
||||||
|
|
||||||
|
|
||||||
|
def _all_tools() -> list[Any]:
|
||||||
|
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
|
||||||
|
|
||||||
|
|
||||||
|
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
||||||
|
debug = context.get("_debug")
|
||||||
|
if isinstance(debug, dict):
|
||||||
|
request_id = debug.get("request_id")
|
||||||
|
if isinstance(request_id, str) and request_id:
|
||||||
|
return request_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
sanitized = dict(context)
|
||||||
|
sanitized.pop("_debug", None)
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
|
||||||
|
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_upcoming_timeline_query(message: str) -> bool:
|
||||||
|
lowered = message.lower()
|
||||||
|
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
|
||||||
|
has_timeline_topic = any(
|
||||||
|
token in lowered
|
||||||
|
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
|
||||||
|
)
|
||||||
|
return has_upcoming and has_timeline_topic
|
||||||
|
|
||||||
|
|
||||||
|
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
|
||||||
|
match = _TIMELINE_DMY_RE.search(dmy)
|
||||||
|
if not match:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
parsed = date(
|
||||||
|
int(match.group("y")),
|
||||||
|
int(match.group("m")),
|
||||||
|
int(match.group("d")),
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
return True
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
return parsed >= today and parsed.year == today.year and parsed.month == today.month
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_tagged_list_lines(text: str, message: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
upcoming_timeline_only = _is_upcoming_timeline_query(message)
|
||||||
|
output_lines: list[str] = []
|
||||||
|
|
||||||
|
for line in text.splitlines():
|
||||||
|
matches = list(_TAG_LINE_RE.finditer(line))
|
||||||
|
if not matches:
|
||||||
|
output_lines.append(line)
|
||||||
|
continue
|
||||||
|
|
||||||
|
had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)")
|
||||||
|
if not had_non_tag_text and len(matches) == 1:
|
||||||
|
tag_text = matches[0].group(0)
|
||||||
|
if (
|
||||||
|
upcoming_timeline_only
|
||||||
|
and "<timeline>" in tag_text
|
||||||
|
and not _timeline_date_in_current_month_or_future(line)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
output_lines.append(tag_text)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
tag_text = match.group(0)
|
||||||
|
if (
|
||||||
|
upcoming_timeline_only
|
||||||
|
and "<timeline>" in tag_text
|
||||||
|
and not _timeline_date_in_current_month_or_future(line)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
output_lines.append(tag_text)
|
||||||
|
|
||||||
|
return "\n".join(output_lines)
|
||||||
|
|
||||||
|
|
||||||
|
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
|
||||||
|
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
|
||||||
|
_FLOATING_EMPTY_FALLBACK = "No results found."
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_floating_markup_fragment(text: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
cleaned = _GENERIC_TAG_RE.sub("", text)
|
||||||
|
return _BRACKETED_ID_RE.sub("", cleaned)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_floating_markup(text: str) -> str:
|
||||||
|
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
cleaned = _strip_floating_markup_fragment(text)
|
||||||
|
# Collapse excessive spaces introduced by tag/id removal while preserving lines.
|
||||||
|
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
|
||||||
|
return "\n".join(line for line in lines if line)
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_from_raw_floating_text(raw_text: str) -> str:
|
||||||
|
fallback = _strip_floating_markup_fragment(raw_text or "")
|
||||||
|
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
|
||||||
|
return fallback or _FLOATING_EMPTY_FALLBACK
|
||||||
|
|
||||||
|
|
||||||
|
class _FloatingStreamSanitizer:
|
||||||
|
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._pending = ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_safe_boundary(text: str) -> tuple[str, str]:
|
||||||
|
boundary = len(text)
|
||||||
|
|
||||||
|
last_lt = text.rfind("<")
|
||||||
|
if last_lt != -1 and ">" not in text[last_lt:]:
|
||||||
|
boundary = min(boundary, last_lt)
|
||||||
|
|
||||||
|
last_lb = text.rfind("[")
|
||||||
|
if last_lb != -1 and "]" not in text[last_lb:]:
|
||||||
|
boundary = min(boundary, last_lb)
|
||||||
|
|
||||||
|
if boundary == len(text):
|
||||||
|
return text, ""
|
||||||
|
return text[:boundary], text[boundary:]
|
||||||
|
|
||||||
|
def feed(self, chunk: str) -> str:
|
||||||
|
combined = f"{self._pending}{chunk}"
|
||||||
|
safe_text, self._pending = self._split_safe_boundary(combined)
|
||||||
|
return _strip_floating_markup_fragment(safe_text)
|
||||||
|
|
||||||
|
def finalize(self) -> str:
|
||||||
|
# Drop dangling unfinished wrappers at the very end.
|
||||||
|
tail = re.sub(r"<[^>\n]*$", "", self._pending)
|
||||||
|
tail = re.sub(r"\[[^\]\n]*$", "", tail)
|
||||||
|
self._pending = ""
|
||||||
|
return _strip_floating_markup_fragment(tail)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_memory_label(path_or_label: str) -> str:
|
||||||
|
value = path_or_label.strip()
|
||||||
|
if value.startswith("/memories/"):
|
||||||
|
value = value[len("/memories/"):]
|
||||||
|
value = value.strip("/")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
@tool
|
||||||
|
async def memory_list_blocks() -> str:
|
||||||
|
"""List all core memory blocks currently stored for the user."""
|
||||||
|
logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
blocks = await memory.list_core_blocks(user_id)
|
||||||
|
if not blocks:
|
||||||
|
return "No memory blocks found."
|
||||||
|
lines = [f"- {b['label']}: {b['value']}" for b in blocks]
|
||||||
|
return "Memory blocks:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_get(path_or_label: str) -> str:
|
||||||
|
"""Get one memory block by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
value = await memory.get_core_block(user_id, label)
|
||||||
|
if value is None:
|
||||||
|
return f"Memory block '{label}' not found."
|
||||||
|
return f"Memory block '{label}':\n{value}"
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_create(path_or_label: str, value: str) -> str:
|
||||||
|
"""Create or overwrite a memory block value by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.update_core(user_id, label, value, trace_id=trace_id)
|
||||||
|
return f"Memory block '{label}' saved."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_append(path_or_label: str, content: str) -> str:
|
||||||
|
"""Append content to a memory block, creating it if missing."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.append_core(user_id, label, content)
|
||||||
|
return f"Memory block '{label}' appended."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
|
||||||
|
"""Replace one exact string in a memory block."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
changed = await memory.replace_core(user_id, label, old_string, new_string)
|
||||||
|
if not changed:
|
||||||
|
return f"No replacement made in '{label}' (old string not found)."
|
||||||
|
return f"Memory block '{label}' updated."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_delete(path_or_label: str) -> str:
|
||||||
|
"""Delete a memory block by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
deleted = await memory.delete_core(user_id, label)
|
||||||
|
if not deleted:
|
||||||
|
return f"Memory block '{label}' not found."
|
||||||
|
return f"Memory block '{label}' deleted."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def archival_memory_insert(content: str) -> str:
|
||||||
|
"""Insert a long-term archival memory entry."""
|
||||||
|
logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.insert_archival(user_id, content, source="assistant")
|
||||||
|
return "Archival memory saved."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def archival_memory_search(query: str, top_k: int = 5) -> str:
|
||||||
|
"""Search long-term archival memory by semantic fallback (keyword currently)."""
|
||||||
|
logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
results = await memory.search_archival(user_id, query, top_k=top_k)
|
||||||
|
if not results:
|
||||||
|
return "No archival memory results found."
|
||||||
|
lines = [f"- {item}" for item in results]
|
||||||
|
return "Archival memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def conversation_search(query: str, top_k: int = 5) -> str:
|
||||||
|
"""Search recall memory from prior episodic conversation summaries."""
|
||||||
|
logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
results = await memory.search_recall(user_id, query, top_k=top_k)
|
||||||
|
if not results:
|
||||||
|
return "No recall memory results found."
|
||||||
|
lines = [f"- {item}" for item in results]
|
||||||
|
return "Recall memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
return [
|
||||||
|
memory_list_blocks,
|
||||||
|
memory_get,
|
||||||
|
memory_create,
|
||||||
|
memory_append,
|
||||||
|
memory_replace,
|
||||||
|
memory_delete,
|
||||||
|
archival_memory_insert,
|
||||||
|
archival_memory_search,
|
||||||
|
conversation_search,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
|
||||||
|
lowered = message.lower()
|
||||||
|
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
|
||||||
|
return "timeline"
|
||||||
|
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
|
||||||
|
return "task"
|
||||||
|
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
|
||||||
|
return "note"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
|
||||||
|
type_raw = str(payload.get("type") or "").strip().lower()
|
||||||
|
domain_type: FloatingDomainType = "task"
|
||||||
|
if type_raw in {"task", "timeline", "project", "node"}:
|
||||||
|
domain_type = type_raw
|
||||||
|
|
||||||
|
id_value = payload.get("id")
|
||||||
|
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
|
||||||
|
if domain_type == "project" and not domain_id:
|
||||||
|
domain_id = fallback_id
|
||||||
|
|
||||||
|
section_raw = payload.get("section")
|
||||||
|
section: FloatingDomainSection | None = None
|
||||||
|
if isinstance(section_raw, str):
|
||||||
|
section_candidate = section_raw.strip().lower()
|
||||||
|
if section_candidate in {"task", "timeline", "note"}:
|
||||||
|
section = section_candidate
|
||||||
|
|
||||||
|
if domain_type != "project":
|
||||||
|
section = None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": domain_type,
|
||||||
|
"id": domain_id,
|
||||||
|
"section": section,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_object(text: str) -> dict[str, Any] | None:
|
||||||
|
raw = text.strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw)
|
||||||
|
return parsed if isinstance(parsed, dict) else None
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = json.loads(match.group(0))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
return parsed if isinstance(parsed, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
|
||||||
|
section = _detect_domain_section(message)
|
||||||
|
scope = context.get("scope") if isinstance(context, dict) else None
|
||||||
|
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
||||||
|
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
||||||
|
|
||||||
|
if isinstance(scope, dict):
|
||||||
|
scope_type = str(scope.get("type") or "").strip().lower()
|
||||||
|
scope_id = scope.get("id")
|
||||||
|
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
|
||||||
|
|
||||||
|
if scope_type in {"task", "tasks"}:
|
||||||
|
return {"type": "task", "id": scope_id_value, "section": None}
|
||||||
|
if scope_type in {"project", "projects"}:
|
||||||
|
project_scope_id = scope_id_value or project_id
|
||||||
|
return {
|
||||||
|
"type": "project",
|
||||||
|
"id": project_scope_id,
|
||||||
|
"section": section,
|
||||||
|
}
|
||||||
|
if scope_type in {"note", "notes"}:
|
||||||
|
return {
|
||||||
|
"type": "node",
|
||||||
|
"id": scope_id_value,
|
||||||
|
"section": None,
|
||||||
|
}
|
||||||
|
if scope_type in {"timeline", "timelines"}:
|
||||||
|
return {"type": "timeline", "id": scope_id_value, "section": None}
|
||||||
|
|
||||||
|
lowered = message.lower()
|
||||||
|
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
|
||||||
|
return {
|
||||||
|
"type": "project",
|
||||||
|
"id": project_id,
|
||||||
|
"section": section,
|
||||||
|
}
|
||||||
|
if section == "timeline":
|
||||||
|
return {"type": "timeline", "id": None, "section": None}
|
||||||
|
if section == "note":
|
||||||
|
return {"type": "node", "id": None, "section": None}
|
||||||
|
return {"type": "task", "id": None, "section": None}
|
||||||
|
|
||||||
|
|
||||||
|
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
|
||||||
|
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
||||||
|
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
||||||
|
|
||||||
|
classifier_context = {
|
||||||
|
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
|
||||||
|
"resolved_project_id": project_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm = get_llm()
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[
|
||||||
|
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_SYSTEM),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"Message:\n{message}\n\n"
|
||||||
|
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
parsed = _parse_json_object(_as_text(response.content))
|
||||||
|
if parsed is not None:
|
||||||
|
domain = _normalize_domain_payload(parsed, project_id)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
|
||||||
|
domain.get("type"),
|
||||||
|
domain.get("id"),
|
||||||
|
domain.get("section"),
|
||||||
|
)
|
||||||
|
return domain
|
||||||
|
logger.warning("deep_agent: floating_domain classifier returned non-json output")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
|
||||||
|
|
||||||
|
return _infer_floating_domain_rule_based(message, context)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_single_agent(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
system_prompt: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
max_steps: int = 6,
|
||||||
|
) -> str:
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
|
llm = get_llm()
|
||||||
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
|
model_context = _context_for_model(context)
|
||||||
|
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"User message:\n{message}\n\n"
|
||||||
|
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls_count = 0
|
||||||
|
collected: list[dict[str, Any]] = []
|
||||||
|
set_tool_result_collector(collected)
|
||||||
|
try:
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
final_text = _as_text(response.content)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
len(final_text),
|
||||||
|
)
|
||||||
|
return final_text
|
||||||
|
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_calls_count += 1
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:1200],
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
final_text = _as_text(final.content)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
len(final_text),
|
||||||
|
)
|
||||||
|
return final_text
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_single_agent_stream(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
system_prompt: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
max_steps: int = 6,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
|
llm = get_llm()
|
||||||
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
|
model_context = _context_for_model(context)
|
||||||
|
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
f"User message:\n{message}\n\n"
|
||||||
|
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls_count = 0
|
||||||
|
streamed_chars = 0
|
||||||
|
collected: list[dict[str, Any]] = []
|
||||||
|
set_tool_result_collector(collected)
|
||||||
|
try:
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
emitted_any = False
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
|
if token:
|
||||||
|
streamed_chars += len(token)
|
||||||
|
emitted_any = True
|
||||||
|
yield "token", token
|
||||||
|
|
||||||
|
# Some providers return final text in `response.content` but stream no chunks.
|
||||||
|
if not emitted_any:
|
||||||
|
fallback_text = _as_text(response.content)
|
||||||
|
if fallback_text:
|
||||||
|
streamed_chars += len(fallback_text)
|
||||||
|
yield "token", fallback_text
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
streamed_chars,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_calls_count += 1
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:1200],
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
|
if token:
|
||||||
|
streamed_chars += len(token)
|
||||||
|
yield "token", token
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tool_calls_count,
|
||||||
|
streamed_chars,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
response = await _run_single_agent(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
)
|
||||||
|
return _normalize_tagged_list_lines(response, message)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
domain = await _infer_floating_domain(message, prepared_context)
|
||||||
|
response = await _run_single_agent(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
)
|
||||||
|
sanitized = _strip_floating_markup(response)
|
||||||
|
if not sanitized and response:
|
||||||
|
sanitized = _fallback_from_raw_floating_text(response)
|
||||||
|
return sanitized, domain
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
text_chunks: list[str] = []
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
):
|
||||||
|
event_type, data = event
|
||||||
|
if event_type != "token":
|
||||||
|
yield event
|
||||||
|
continue
|
||||||
|
text_chunks.append(str(data or ""))
|
||||||
|
|
||||||
|
normalized = _normalize_tagged_list_lines("".join(text_chunks), message)
|
||||||
|
if normalized:
|
||||||
|
yield "token", normalized
|
||||||
|
|
||||||
|
|
||||||
|
async def run_floating_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
domain = await _infer_floating_domain(message, prepared_context)
|
||||||
|
yield "floating_domain", domain
|
||||||
|
|
||||||
|
sanitizer = _FloatingStreamSanitizer()
|
||||||
|
emitted_sanitized = False
|
||||||
|
raw_chunks: list[str] = []
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
||||||
|
message=message,
|
||||||
|
context=prepared_context,
|
||||||
|
):
|
||||||
|
event_type, data = event
|
||||||
|
if event_type != "token":
|
||||||
|
yield event
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_chunk = str(data or "")
|
||||||
|
raw_chunks.append(raw_chunk)
|
||||||
|
sanitized_chunk = sanitizer.feed(raw_chunk)
|
||||||
|
if sanitized_chunk:
|
||||||
|
emitted_sanitized = True
|
||||||
|
yield "token", sanitized_chunk
|
||||||
|
|
||||||
|
tail = sanitizer.finalize()
|
||||||
|
if tail:
|
||||||
|
emitted_sanitized = True
|
||||||
|
yield "token", tail
|
||||||
|
|
||||||
|
if not emitted_sanitized and raw_chunks:
|
||||||
|
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
|
||||||
|
|
||||||
|
|
||||||
|
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
||||||
|
"""Compatibility helper kept for callers that expect explicit memory update API."""
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.update_core(user_id, key, value)
|
||||||
@@ -3,20 +3,15 @@
|
|||||||
Maintains in-memory state for all active Electron → backend WebSocket
|
Maintains in-memory state for all active Electron → backend WebSocket
|
||||||
connections. One connection per user (latest replaces previous).
|
connections. One connection per user (latest replaces previous).
|
||||||
|
|
||||||
The manager participates in two interaction patterns:
|
The manager handles the **tool-call round-trip** pattern:
|
||||||
|
- Backend sends ``tool_call`` frame → Electron executes the action →
|
||||||
1. **Tool-call round-trip** (bidirectional CRUD):
|
returns ``tool_result`` frame.
|
||||||
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
|
||||||
``tool_result`` frame.
|
|
||||||
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||||
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||||
receive the result dict from Electron.
|
receive the result dict from Electron.
|
||||||
|
|
||||||
2. **Agent-data streaming** (local directory agent runs):
|
This pattern is used by all tools (CRUD, file-system, etc.) via
|
||||||
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
``execute_on_client()`` in ``ws_context.py``.
|
||||||
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
|
||||||
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
|
||||||
a specific ``run_id`` so the agent runner can iterate frames.
|
|
||||||
|
|
||||||
The ``device_manager`` module-level singleton is imported by both the
|
The ``device_manager`` module-level singleton is imported by both the
|
||||||
device WS route and the agent runner.
|
device WS route and the agent runner.
|
||||||
@@ -42,8 +37,6 @@ class DeviceConnection:
|
|||||||
device_id: str
|
device_id: str
|
||||||
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
# Per-run queues for agent_data / agent_complete frames.
|
|
||||||
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceConnectionManager:
|
class DeviceConnectionManager:
|
||||||
@@ -153,31 +146,6 @@ class DeviceConnectionManager:
|
|||||||
if fut is not None and not fut.done():
|
if fut is not None and not fut.done():
|
||||||
fut.set_result(result)
|
fut.set_result(result)
|
||||||
|
|
||||||
# ── Agent-data queue ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
def get_agent_data_queue(
|
|
||||||
self, user_id: str, run_id: str
|
|
||||||
) -> asyncio.Queue[dict | None]:
|
|
||||||
"""Return (creating if absent) the queue for *run_id* agent frames.
|
|
||||||
|
|
||||||
The agent runner reads from this queue. The device WS handler writes
|
|
||||||
to it. ``None`` is the sentinel that signals the stream is finished.
|
|
||||||
"""
|
|
||||||
conn = self._connections.get(user_id)
|
|
||||||
if conn is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"get_agent_data_queue: user {user_id!r} is not connected"
|
|
||||||
)
|
|
||||||
if run_id not in conn.agent_data_queues:
|
|
||||||
conn.agent_data_queues[run_id] = asyncio.Queue()
|
|
||||||
return conn.agent_data_queues[run_id]
|
|
||||||
|
|
||||||
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
|
|
||||||
"""Remove the queue for *run_id* once a run has completed."""
|
|
||||||
conn = self._connections.get(user_id)
|
|
||||||
if conn:
|
|
||||||
conn.agent_data_queues.pop(run_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — import this everywhere.
|
# Module-level singleton — import this everywhere.
|
||||||
device_manager = DeviceConnectionManager()
|
device_manager = DeviceConnectionManager()
|
||||||
|
|||||||
@@ -1,222 +0,0 @@
|
|||||||
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.schemas import ExecutionPlan, PlanStep
|
|
||||||
|
|
||||||
|
|
||||||
# ── Prompt Template Registry ──────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplateRegistry:
|
|
||||||
"""Server-side store mapping template IDs to prompt text.
|
|
||||||
|
|
||||||
Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
|
|
||||||
The actual prompt text is resolved here on the server, keeping prompt IP
|
|
||||||
out of API responses.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._templates: dict[str, str] = {}
|
|
||||||
|
|
||||||
def register(self, template_id: str, prompt_text: str) -> None:
|
|
||||||
self._templates[template_id] = prompt_text
|
|
||||||
|
|
||||||
def get(self, template_id: str) -> str:
|
|
||||||
"""Resolve a template ID to its prompt text.
|
|
||||||
|
|
||||||
Raises ``KeyError`` if the template is not registered.
|
|
||||||
"""
|
|
||||||
text = self._templates.get(template_id)
|
|
||||||
if text is None:
|
|
||||||
raise KeyError(f"Template not found: {template_id!r}")
|
|
||||||
return text
|
|
||||||
|
|
||||||
def has(self, template_id: str) -> bool:
|
|
||||||
return template_id in self._templates
|
|
||||||
|
|
||||||
def list_ids(self) -> list[str]:
|
|
||||||
"""Return all registered template IDs (never the text)."""
|
|
||||||
return list(self._templates.keys())
|
|
||||||
|
|
||||||
|
|
||||||
# ── Execution Plan Builder ────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlanBuilder:
|
|
||||||
"""Fluent builder for ``ExecutionPlan`` objects.
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_llm_step("tpl_task_agent_default", {"message": user_msg})
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, agent: str) -> None:
|
|
||||||
self._agent = agent
|
|
||||||
self._steps: list[PlanStep] = []
|
|
||||||
|
|
||||||
# ── step adders ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def add_step(
|
|
||||||
self, action: str, params: dict[str, Any] | None = None
|
|
||||||
) -> ExecutionPlanBuilder:
|
|
||||||
"""Append a generic action step with optional parameters."""
|
|
||||||
self._steps.append(PlanStep(action=action, variables=params))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_llm_step(
|
|
||||||
self, template_id: str, variables: dict[str, Any] | None = None
|
|
||||||
) -> ExecutionPlanBuilder:
|
|
||||||
"""Append an LLM step referencing a server-side template by ID."""
|
|
||||||
self._steps.append(
|
|
||||||
PlanStep(action="llm", prompt_template=template_id, variables=variables)
|
|
||||||
)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
|
|
||||||
"""Append a step whose input comes from the output of an earlier step."""
|
|
||||||
self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
|
|
||||||
return self
|
|
||||||
|
|
||||||
# ── build ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def build(self) -> ExecutionPlan:
|
|
||||||
"""Validate step references and return the ``ExecutionPlan``.
|
|
||||||
|
|
||||||
Raises ``ValueError`` if any ``data_from_step`` references a
|
|
||||||
non-existent or future step index.
|
|
||||||
"""
|
|
||||||
for i, step in enumerate(self._steps):
|
|
||||||
if step.data_from_step is not None:
|
|
||||||
if not (0 <= step.data_from_step < i):
|
|
||||||
raise ValueError(
|
|
||||||
f"Step {i}: data_from_step={step.data_from_step} must "
|
|
||||||
f"reference a preceding step index in range 0..{i - 1}"
|
|
||||||
)
|
|
||||||
return ExecutionPlan(agent=self._agent, steps=list(self._steps))
|
|
||||||
|
|
||||||
|
|
||||||
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PlanCache:
|
|
||||||
"""In-memory LRU cache for ``ExecutionPlan`` objects.
|
|
||||||
|
|
||||||
Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
|
|
||||||
The cache also serves as a runtime memoisation layer so that repeated
|
|
||||||
identical intent classifications can skip re-building the plan.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, maxsize: int = 1000) -> None:
|
|
||||||
self._maxsize = maxsize
|
|
||||||
self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
|
|
||||||
|
|
||||||
def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
|
|
||||||
"""Store *plan* under *key*, evicting the LRU entry if at capacity."""
|
|
||||||
if key in self._cache:
|
|
||||||
del self._cache[key] # remove so re-insertion places it at the end
|
|
||||||
elif len(self._cache) >= self._maxsize:
|
|
||||||
self._cache.popitem(last=False) # evict least-recently-used
|
|
||||||
self._cache[key] = plan
|
|
||||||
|
|
||||||
def get_plan(self, key: str) -> ExecutionPlan | None:
|
|
||||||
"""Return the cached plan for *key*, or ``None`` if not present.
|
|
||||||
|
|
||||||
Accessing a plan marks it as most-recently used.
|
|
||||||
"""
|
|
||||||
if key not in self._cache:
|
|
||||||
return None
|
|
||||||
self._cache.move_to_end(key)
|
|
||||||
return self._cache[key]
|
|
||||||
|
|
||||||
def get_all_playbooks(self) -> list[ExecutionPlan]:
|
|
||||||
"""Return all cached plans (most-recently used last)."""
|
|
||||||
return list(self._cache.values())
|
|
||||||
|
|
||||||
|
|
||||||
# ── Module-level singletons ───────────────────────────────────────────
|
|
||||||
|
|
||||||
template_registry = PromptTemplateRegistry()
|
|
||||||
plan_cache = PlanCache()
|
|
||||||
|
|
||||||
|
|
||||||
def _register_builtin_templates() -> None:
|
|
||||||
"""Register the built-in server-side prompt templates.
|
|
||||||
|
|
||||||
These strings never leave the server. Clients only receive the IDs.
|
|
||||||
"""
|
|
||||||
_tpls: dict[str, str] = {
|
|
||||||
"tpl_task_agent_default": (
|
|
||||||
"You are a task management assistant. Help the user create, update, "
|
|
||||||
"list, and track tasks. Use correct status values (todo, in_progress, "
|
|
||||||
"done) and priority values (high, medium, low) from the workspace model."
|
|
||||||
),
|
|
||||||
"tpl_timeline_agent_default": (
|
|
||||||
"You are a project timeline assistant. Help the user create and manage "
|
|
||||||
"milestone timelines on their projects. Every timeline requires a "
|
|
||||||
"project_id and a date expressed as a Unix timestamp in milliseconds."
|
|
||||||
),
|
|
||||||
"tpl_project_agent_default": (
|
|
||||||
"You are a project management assistant. Help the user create, find, "
|
|
||||||
"update, and archive projects. Projects have a name, an optional client, "
|
|
||||||
"and a status of either active or archived."
|
|
||||||
),
|
|
||||||
"tpl_note_agent_default": (
|
|
||||||
"You are a note-taking assistant. Help the user create, retrieve, update, "
|
|
||||||
"and delete Markdown notes. Notes can optionally be linked to a project."
|
|
||||||
),
|
|
||||||
"tpl_task_extract_from_project": (
|
|
||||||
"Extract all actionable tasks from the provided project context. "
|
|
||||||
"Return a structured list of tasks, each with a title, inferred priority "
|
|
||||||
"(high, medium, or low), suggested status (todo), and a due_date in "
|
|
||||||
"milliseconds where a deadline can be inferred."
|
|
||||||
),
|
|
||||||
"tpl_note_weekly_summary": (
|
|
||||||
"Generate a weekly project summary note from the provided workspace data. "
|
|
||||||
"Include: tasks completed this week, tasks due soon, active projects, "
|
|
||||||
"and upcoming timelines. Format the output as clean Markdown."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
for tid, text in _tpls.items():
|
|
||||||
template_registry.register(tid, text)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_playbooks() -> None:
|
|
||||||
"""Pre-build and cache the built-in playbooks."""
|
|
||||||
playbooks: list[tuple[str, ExecutionPlan]] = [
|
|
||||||
(
|
|
||||||
"create_tasks_from_project",
|
|
||||||
ExecutionPlanBuilder("project_agent")
|
|
||||||
.add_llm_step(
|
|
||||||
"tpl_task_extract_from_project",
|
|
||||||
{"source": "project_context"},
|
|
||||||
)
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build(),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"generate_weekly_note",
|
|
||||||
ExecutionPlanBuilder("note_agent")
|
|
||||||
.add_llm_step(
|
|
||||||
"tpl_note_weekly_summary",
|
|
||||||
{"period": "last_7_days"},
|
|
||||||
)
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build(),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
for key, plan in playbooks:
|
|
||||||
plan_cache.cache_plan(key, plan)
|
|
||||||
|
|
||||||
|
|
||||||
# Initialise on module load
|
|
||||||
_register_builtin_templates()
|
|
||||||
_load_playbooks()
|
|
||||||
@@ -18,6 +18,7 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import litellm
|
import litellm
|
||||||
@@ -32,6 +33,14 @@ from app.config.settings import settings
|
|||||||
# Drop them silently instead of raising UnsupportedParamsError.
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
# Some provider responses include a plain dict in the `usage` field where a
|
||||||
|
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
|
|||||||
@@ -50,7 +50,13 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────────────
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
async def enrich_context(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Build memory context dict to inject into the orchestrator before LLM call.
|
"""Build memory context dict to inject into the orchestrator before LLM call.
|
||||||
|
|
||||||
Returns a dict with keys:
|
Returns a dict with keys:
|
||||||
@@ -65,9 +71,21 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
core = await self._load_core(user_id, fernet)
|
||||||
associative = await self._load_associative(user_id, message, fernet)
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
episodic = await self._load_episodic(user_id, fernet)
|
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
len(core),
|
||||||
|
len(associative),
|
||||||
|
len(episodic),
|
||||||
|
len(proactive),
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"core_memory": core,
|
"core_memory": core,
|
||||||
"associative_memory": associative,
|
"associative_memory": associative,
|
||||||
@@ -81,6 +99,7 @@ class MemoryMiddleware:
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
response: str,
|
response: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Summarise and store a completed interaction in episodic memory.
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
@@ -103,11 +122,19 @@ class MemoryMiddleware:
|
|||||||
self._db.add(row)
|
self._db.add(row)
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||||
"""Upsert a core memory key/value for a user."""
|
"""Upsert a core memory key/value for a user."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -133,10 +160,176 @@ class MemoryMiddleware:
|
|||||||
))
|
))
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: update_core trace=%s user=%s tier=%s key=%s",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
key,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
||||||
|
"""Return core memory as editable blocks (label/value)."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore)
|
||||||
|
.where(MemoryCore.user_id == user_id)
|
||||||
|
.order_by(MemoryCore.key.asc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[dict[str, str]] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append({"label": row.key, "value": plaintext})
|
||||||
|
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
||||||
|
"""Return a single core memory block value by label."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == label,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
|
||||||
|
return None
|
||||||
|
value = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def delete_core(self, user_id: str, label: str) -> bool:
|
||||||
|
"""Delete a core memory block by label. Returns True if deleted."""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == label,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self._db.delete(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("memory: delete_core user=%s label=%s", user_id, label)
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
||||||
|
"""Append content to a core block, creating it if missing."""
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None:
|
||||||
|
await self.update_core(user_id, label, content)
|
||||||
|
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
|
||||||
|
return
|
||||||
|
await self.update_core(user_id, label, f"{current}\n{content}")
|
||||||
|
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
|
||||||
|
|
||||||
|
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
||||||
|
"""Replace one exact string inside a core block. Returns False if not found."""
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None or old not in current:
|
||||||
|
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
|
||||||
|
return False
|
||||||
|
await self.update_core(user_id, label, current.replace(old, new, 1))
|
||||||
|
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||||
|
"""Insert a long-term archival memory entry."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, content)
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
content_encrypted=encrypted,
|
||||||
|
embedding=None,
|
||||||
|
entity_type=source,
|
||||||
|
entity_id=None,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc())
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
"""Search recall memory (episodic summaries) by keyword."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
# ── Private helpers ───────────────────────────────────────────────────────
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
@@ -148,6 +341,16 @@ class MemoryMiddleware:
|
|||||||
return None
|
return None
|
||||||
return Fernet(user.encryption_key.encode())
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
|
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
||||||
|
"""Load lightweight user debug fields for trace logs."""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
return {"tier": None}
|
||||||
|
return {
|
||||||
|
"tier": user.tier,
|
||||||
|
}
|
||||||
|
|
||||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
@@ -183,10 +386,17 @@ class MemoryMiddleware:
|
|||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
async def _load_episodic(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
fernet: Fernet,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||||
|
if session_id:
|
||||||
|
query = query.where(MemoryEpisodic.session_id == session_id)
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryEpisodic)
|
query
|
||||||
.where(MemoryEpisodic.user_id == user_id)
|
|
||||||
.order_by(MemoryEpisodic.created_at.desc())
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
.limit(_EPISODIC_RECENT_N)
|
.limit(_EPISODIC_RECENT_N)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,210 +0,0 @@
|
|||||||
"""Orchestrator — LLM-based intent router and agent pipeline."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, AsyncGenerator
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
|
|
||||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
|
||||||
from app.core.llm import get_router_llm
|
|
||||||
from app.core.agent_registry import registry as _default_registry
|
|
||||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
|
||||||
|
|
||||||
_FALLBACK_AGENT = "task_agent"
|
|
||||||
|
|
||||||
_CLASSIFY_SYSTEM = (
|
|
||||||
"You are an intent classifier. Given the user message and context, decide "
|
|
||||||
"which agent to route to.\n"
|
|
||||||
"Available agents: {agents}\n"
|
|
||||||
"Respond with just the agent name, nothing else."
|
|
||||||
)
|
|
||||||
|
|
||||||
_SYNTHESIZE_HUMAN = (
|
|
||||||
"Combine the following agent results into one coherent response.\n\n"
|
|
||||||
"Agent results:\n{results}\n\n"
|
|
||||||
"Original message: {message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_llm():
|
|
||||||
return get_router_llm()
|
|
||||||
|
|
||||||
|
|
||||||
async def classify_intent(
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> str:
|
|
||||||
"""Use gpt-4o-mini to classify intent and return the matching agent name.
|
|
||||||
|
|
||||||
Falls back to ``task_agent`` when the registry is empty or the model
|
|
||||||
returns a name that is not registered.
|
|
||||||
"""
|
|
||||||
agents = reg.list_agents()
|
|
||||||
if not agents:
|
|
||||||
return _FALLBACK_AGENT
|
|
||||||
|
|
||||||
system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
|
|
||||||
# Truncate context to keep the classification prompt short
|
|
||||||
human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
|
|
||||||
|
|
||||||
llm = _make_llm()
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[SystemMessage(content=system), HumanMessage(content=human)]
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_name = str(response.content).strip().lower()
|
|
||||||
known = {a["name"] for a in agents}
|
|
||||||
return agent_name if agent_name in known else _FALLBACK_AGENT
|
|
||||||
|
|
||||||
|
|
||||||
async def route_single(
|
|
||||||
agent_name: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""Route to a single agent and wrap the result in a ``ChatResponse``."""
|
|
||||||
response_text = await reg.call_agent(agent_name, message, context)
|
|
||||||
return ChatResponse(response=response_text)
|
|
||||||
|
|
||||||
|
|
||||||
async def route_pipeline(
|
|
||||||
agent_names: list[str],
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""Execute agents sequentially; each agent receives previous results in context.
|
|
||||||
|
|
||||||
A final LLM synthesis call merges all results into one coherent response.
|
|
||||||
"""
|
|
||||||
previous_results: list[str] = []
|
|
||||||
|
|
||||||
for agent_name in agent_names:
|
|
||||||
ctx = {**context, "previous_results": list(previous_results)}
|
|
||||||
result = await reg.call_agent(agent_name, message, ctx)
|
|
||||||
previous_results.append(result)
|
|
||||||
|
|
||||||
results_str = "\n\n".join(
|
|
||||||
f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
|
|
||||||
)
|
|
||||||
human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
|
|
||||||
llm = _make_llm()
|
|
||||||
synthesis = await llm.ainvoke([HumanMessage(content=human)])
|
|
||||||
return ChatResponse(response=str(synthesis.content))
|
|
||||||
|
|
||||||
|
|
||||||
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
|
|
||||||
"""Build an ``ExecutionPlan`` for the resolved agent.
|
|
||||||
|
|
||||||
Uses ``ExecutionPlanBuilder`` with the server-side template registry.
|
|
||||||
If a default template exists for the agent, an LLM step is emitted;
|
|
||||||
otherwise a plain ``handle`` action step is used.
|
|
||||||
"""
|
|
||||||
from app.core.execution_plan import ExecutionPlanBuilder, template_registry
|
|
||||||
|
|
||||||
template_id = f"tpl_{agent_name}_default"
|
|
||||||
builder = ExecutionPlanBuilder(agent_name)
|
|
||||||
if template_registry.has(template_id):
|
|
||||||
builder.add_llm_step(template_id, {"message": message})
|
|
||||||
else:
|
|
||||||
builder.add_step("handle", {"message": message})
|
|
||||||
return builder.build()
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate(
|
|
||||||
request: ChatRequest,
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> ChatResponse | ExecutionPlan:
|
|
||||||
"""Main orchestration entry point.
|
|
||||||
|
|
||||||
* Classifies the user's intent to select an agent.
|
|
||||||
* ``execution_mode == 'direct'``: routes to the agent and returns a
|
|
||||||
``ChatResponse``.
|
|
||||||
* ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
|
|
||||||
resolved agent and a template-ID-only step (prompt IP stays server-side).
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
|
|
||||||
context = request.context.model_dump()
|
|
||||||
agent_name = await classify_intent(request.message, context, reg)
|
|
||||||
|
|
||||||
if request.execution_mode == "direct":
|
|
||||||
return await route_single(agent_name, request.message, context, reg)
|
|
||||||
|
|
||||||
# plan mode — return plan, do not execute
|
|
||||||
return _build_plan(agent_name, request.message)
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_v3(
|
|
||||||
user_id: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> tuple[str, ChatAgent]:
|
|
||||||
"""v3 orchestration — returns (agent_name, agent_instance); caller drives execution.
|
|
||||||
|
|
||||||
Classifies intent and instantiates the matching agent. The caller is responsible
|
|
||||||
for invoking handle(), handle_stream(), or _tool_loop_stream() as needed.
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
agent_name = await classify_intent(message, context, reg)
|
|
||||||
return agent_name, reg.get(agent_name)
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_v3_stream(
|
|
||||||
user_id: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
agent_holder: list | None = None,
|
|
||||||
) -> AsyncGenerator[tuple[str, str], None]:
|
|
||||||
"""v3 streaming orchestration — yields (agent_name, token) pairs.
|
|
||||||
|
|
||||||
The first yield always carries the agent_name with an empty token so that
|
|
||||||
callers (e.g. FloatingFormatter) can detect the routing domain before any text
|
|
||||||
tokens arrive.
|
|
||||||
|
|
||||||
If *agent_holder* is provided (a list), the agent instance is appended so
|
|
||||||
callers can access ``agent.tool_results`` after the stream completes.
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
agent_name = await classify_intent(message, context, reg)
|
|
||||||
agent = reg.get(agent_name)
|
|
||||||
if agent_holder is not None:
|
|
||||||
agent_holder.append(agent)
|
|
||||||
yield agent_name, "" # domain signal — no token yet
|
|
||||||
async for token in agent.handle_stream(message, context):
|
|
||||||
yield agent_name, token
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_stream(
|
|
||||||
request: ChatRequest,
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Streaming orchestration — yields plain text chunks only.
|
|
||||||
|
|
||||||
The WebSocket handler in ``app/api/routes/chat.py`` is responsible for
|
|
||||||
wrapping each chunk in a ``text_chunk`` frame and sending the final
|
|
||||||
``final`` frame once the generator is exhausted.
|
|
||||||
|
|
||||||
Agents do not yet support token-level streaming; the full response is
|
|
||||||
fetched first (which may involve multiple WS round-trips for tool calls),
|
|
||||||
then emitted in fixed-size chunks.
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
|
|
||||||
context = request.context.model_dump()
|
|
||||||
agent_name = await classify_intent(request.message, context, reg)
|
|
||||||
response_text = await reg.call_agent(agent_name, request.message, context)
|
|
||||||
|
|
||||||
chunk_size = 50
|
|
||||||
for i in range(0, len(response_text), chunk_size):
|
|
||||||
yield response_text[i : i + chunk_size]
|
|
||||||
@@ -1,244 +1,47 @@
|
|||||||
"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
|
"""Output formatter for deep-agent stream events."""
|
||||||
|
|
||||||
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
|
|
||||||
FloatingFormatter: produces floating_domain, stream_text, stream_end
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.schemas import (
|
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
WsFloatingDomain,
|
|
||||||
WsStreamBlock,
|
|
||||||
WsStreamEnd,
|
|
||||||
WsStreamStart,
|
|
||||||
WsStreamText,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
|
|
||||||
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}
|
|
||||||
|
|
||||||
# Map agent name → floating domain
|
|
||||||
_AGENT_DOMAIN: dict[str, str] = {
|
|
||||||
"task_agent": "tasks",
|
|
||||||
"timeline_agent": "timelines",
|
|
||||||
"note_agent": "notes",
|
|
||||||
"project_agent": "projects",
|
|
||||||
}
|
|
||||||
|
|
||||||
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
|
|
||||||
|
|
||||||
|
|
||||||
class HomeFormatter:
|
class StreamFormatter:
|
||||||
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
||||||
|
|
||||||
The LLM is expected to output a newline-delimited sequence of JSON objects,
|
|
||||||
each with a ``type`` field:
|
|
||||||
- ``text`` → yields WsStreamText immediately (word-by-word)
|
|
||||||
- ``chart`` → buffers full JSON, validates, yields WsStreamBlock
|
|
||||||
- ``entity_ref`` → resolves from tool_results, yields WsStreamBlock
|
|
||||||
- ``table`` → buffers full JSON, validates, yields WsStreamBlock
|
|
||||||
- ``timeline`` → buffers full JSON, validates, yields WsStreamBlock
|
|
||||||
|
|
||||||
Invalid or unknown blocks are logged and skipped — stream never crashes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, request_id: str, tool_results: list[dict]) -> None:
|
|
||||||
self.request_id = request_id
|
|
||||||
self.tool_results = tool_results
|
|
||||||
|
|
||||||
async def format(
|
|
||||||
self,
|
|
||||||
token_stream: AsyncGenerator[tuple[str, str], None],
|
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
|
||||||
|
|
||||||
buffer = ""
|
|
||||||
async for _agent_name, token in token_stream:
|
|
||||||
if not token:
|
|
||||||
continue
|
|
||||||
buffer += token
|
|
||||||
# Flush any complete JSON objects from the buffer
|
|
||||||
async for frame in self._flush_complete_objects(buffer):
|
|
||||||
buffer = "" # reset after flush
|
|
||||||
yield frame
|
|
||||||
break # only one flush per iteration; rest accumulates
|
|
||||||
|
|
||||||
# Flush any remaining content
|
|
||||||
if buffer.strip():
|
|
||||||
async for frame in self._flush_complete_objects(buffer, final=True):
|
|
||||||
yield frame
|
|
||||||
|
|
||||||
yield WsStreamEnd(request_id=self.request_id)
|
|
||||||
|
|
||||||
async def _flush_complete_objects(
|
|
||||||
self, text: str, final: bool = False
|
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
|
||||||
"""Try to parse and yield all complete JSON objects from *text*.
|
|
||||||
|
|
||||||
Yields nothing if text is incomplete JSON (unless *final* is True,
|
|
||||||
in which case remaining text is emitted as plain stream_text).
|
|
||||||
"""
|
|
||||||
remaining = text.strip()
|
|
||||||
while remaining:
|
|
||||||
# Fast path: plain text (not JSON)
|
|
||||||
if not remaining.startswith("{"):
|
|
||||||
# Yield as plain text chunk
|
|
||||||
newline_idx = remaining.find("\n")
|
|
||||||
if newline_idx == -1:
|
|
||||||
if final:
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
|
||||||
remaining = ""
|
|
||||||
else:
|
|
||||||
return # accumulate more
|
|
||||||
else:
|
|
||||||
line = remaining[:newline_idx].strip()
|
|
||||||
remaining = remaining[newline_idx + 1:].strip()
|
|
||||||
if line:
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=line)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Try to decode a JSON object
|
|
||||||
try:
|
|
||||||
obj, end_idx = _try_parse_json(remaining)
|
|
||||||
except ValueError:
|
|
||||||
if final:
|
|
||||||
# Emit as raw text if we can't parse
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
|
||||||
remaining = ""
|
|
||||||
return
|
|
||||||
|
|
||||||
if obj is None:
|
|
||||||
if final:
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
|
||||||
remaining = ""
|
|
||||||
return # incomplete — need more tokens
|
|
||||||
|
|
||||||
remaining = remaining[end_idx:].strip()
|
|
||||||
block_type = obj.get("type")
|
|
||||||
|
|
||||||
frame = self._dispatch_block(obj, block_type)
|
|
||||||
if frame is not None:
|
|
||||||
yield frame
|
|
||||||
|
|
||||||
def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
|
|
||||||
if block_type == "text":
|
|
||||||
content = obj.get("content", "")
|
|
||||||
if content:
|
|
||||||
return WsStreamText(request_id=self.request_id, chunk=str(content))
|
|
||||||
return None
|
|
||||||
|
|
||||||
if block_type == "chart":
|
|
||||||
chart_type = obj.get("chartType")
|
|
||||||
if chart_type not in _VALID_CHART_TYPES:
|
|
||||||
logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
|
|
||||||
return None
|
|
||||||
if not isinstance(obj.get("data"), list):
|
|
||||||
logger.warning("HomeFormatter: chart missing data array — skipping")
|
|
||||||
return None
|
|
||||||
return WsStreamBlock(
|
|
||||||
request_id=self.request_id,
|
|
||||||
block_type="chart",
|
|
||||||
data=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
if block_type == "entity_ref":
|
|
||||||
entity = obj.get("entity")
|
|
||||||
resolved = self._resolve_entity(entity)
|
|
||||||
if resolved is None:
|
|
||||||
logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
|
|
||||||
return None
|
|
||||||
return WsStreamBlock(
|
|
||||||
request_id=self.request_id,
|
|
||||||
block_type="entity_ref",
|
|
||||||
data={"entity": entity, "items": resolved},
|
|
||||||
)
|
|
||||||
|
|
||||||
if block_type == "table":
|
|
||||||
if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
|
|
||||||
logger.warning("HomeFormatter: table missing headers/rows — skipping")
|
|
||||||
return None
|
|
||||||
return WsStreamBlock(
|
|
||||||
request_id=self.request_id,
|
|
||||||
block_type="table",
|
|
||||||
data=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
if block_type == "timeline":
|
|
||||||
if not isinstance(obj.get("timelines"), list):
|
|
||||||
logger.warning("HomeFormatter: timeline missing timelines — skipping")
|
|
||||||
return None
|
|
||||||
return WsStreamBlock(
|
|
||||||
request_id=self.request_id,
|
|
||||||
block_type="timeline",
|
|
||||||
data=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _resolve_entity(self, entity: str | None) -> list[dict] | None:
|
|
||||||
"""Find matching items in tool_results by entity type."""
|
|
||||||
if not entity:
|
|
||||||
return None
|
|
||||||
matches = [r for r in self.tool_results if r.get("entity") == entity]
|
|
||||||
return matches if matches else None
|
|
||||||
|
|
||||||
|
|
||||||
class FloatingFormatter:
|
|
||||||
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
|
||||||
|
|
||||||
Emits floating_domain immediately (from agent_name), then streams all tokens
|
|
||||||
as plain stream_text — no block parsing for floating context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, request_id: str) -> None:
|
def __init__(self, request_id: str) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
|
|
||||||
async def format(
|
async def format(
|
||||||
self,
|
self,
|
||||||
token_stream: AsyncGenerator[tuple[str, str], None],
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
domain_sent = False
|
started = False
|
||||||
|
|
||||||
async for agent_name, token in token_stream:
|
async for event_type, data in event_stream:
|
||||||
if not domain_sent:
|
if event_type == "floating_domain":
|
||||||
domain = _AGENT_DOMAIN.get(agent_name, "tasks")
|
if isinstance(data, dict):
|
||||||
yield WsFloatingDomain(
|
yield WsFloatingDomain(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
domain=domain, # type: ignore[arg-type]
|
domain=data,
|
||||||
)
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event_type != "token":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not started:
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
domain_sent = True
|
started = True
|
||||||
|
|
||||||
if token:
|
text = str(data or "")
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=token)
|
if text:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=text)
|
||||||
|
|
||||||
|
if not started:
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
yield WsStreamEnd(request_id=self.request_id)
|
yield WsStreamEnd(request_id=self.request_id)
|
||||||
|
|
||||||
|
|
||||||
# ── helpers ───────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
|
|
||||||
"""Attempt to parse the first complete JSON object from *text*.
|
|
||||||
|
|
||||||
Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the
|
|
||||||
object is incomplete, and raises ``ValueError`` when text is not JSON.
|
|
||||||
"""
|
|
||||||
decoder = json.JSONDecoder()
|
|
||||||
try:
|
|
||||||
obj, end_idx = decoder.raw_decode(text)
|
|
||||||
if not isinstance(obj, dict):
|
|
||||||
raise ValueError("Expected JSON object")
|
|
||||||
return obj, end_idx
|
|
||||||
except json.JSONDecodeError as exc:
|
|
||||||
# Incomplete JSON — need more tokens
|
|
||||||
if "Unterminated" in str(exc) or exc.pos == len(text):
|
|
||||||
return None, 0
|
|
||||||
raise ValueError(str(exc)) from exc
|
|
||||||
|
|||||||
@@ -18,9 +18,8 @@ from app.config.settings import settings
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup: initialise DB connection pool and agent registry
|
# Startup: ensure agent tool modules are loaded.
|
||||||
from app.core.agent_registry import registry # noqa: F401 — triggers module load
|
import app.agents # noqa: F401
|
||||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@@ -51,18 +50,16 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors
|
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
app.include_router(plans.router, prefix="/api/v1")
|
|
||||||
app.include_router(storage.router, prefix="/api/v1")
|
app.include_router(storage.router, prefix="/api/v1")
|
||||||
app.include_router(vectors.router, prefix="/api/v1")
|
app.include_router(vectors.router, prefix="/api/v1")
|
||||||
app.include_router(backup.router, prefix="/api/v1")
|
app.include_router(backup.router, prefix="/api/v1")
|
||||||
app.include_router(plugins.router, prefix="/api/v1")
|
app.include_router(plugins.router, prefix="/api/v1")
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
app.include_router(agent_setup.router, prefix="/api/v1")
|
|
||||||
app.include_router(device_ws.router, prefix="/api/v1")
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
|||||||
184
app/schemas.py
184
app/schemas.py
@@ -41,41 +41,13 @@ class ChatContext(BaseModel):
|
|||||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class PlanAction(BaseModel):
|
|
||||||
type: Literal[
|
|
||||||
"create_record",
|
|
||||||
"update_record",
|
|
||||||
"delete_record",
|
|
||||||
"index_document",
|
|
||||||
"send_notification",
|
|
||||||
]
|
|
||||||
table: str | None = None
|
|
||||||
data: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
class ChatRequest(BaseModel):
|
||||||
message: str
|
message: str
|
||||||
context: ChatContext = Field(default_factory=ChatContext)
|
context: ChatContext = Field(default_factory=ChatContext)
|
||||||
execution_mode: Literal["direct", "plan"] = "direct"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatResponse(BaseModel):
|
class ChatResponse(BaseModel):
|
||||||
response: str
|
response: str
|
||||||
actions: list[PlanAction] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Execution Plans ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class PlanStep(BaseModel):
|
|
||||||
action: str
|
|
||||||
prompt_template: str | None = None
|
|
||||||
variables: dict[str, Any] | None = None
|
|
||||||
data_from_step: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlan(BaseModel):
|
|
||||||
agent: str
|
|
||||||
steps: list[PlanStep] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Backup ───────────────────────────────────────────────────────────
|
# ── Backup ───────────────────────────────────────────────────────────
|
||||||
@@ -170,21 +142,21 @@ class WsFrameType(str, Enum):
|
|||||||
tool_result = "tool_result"
|
tool_result = "tool_result"
|
||||||
final = "final"
|
final = "final"
|
||||||
ping = "ping"
|
ping = "ping"
|
||||||
agent_run = "agent_run"
|
|
||||||
agent_data = "agent_data"
|
|
||||||
agent_complete = "agent_complete"
|
|
||||||
device_hello = "device_hello"
|
device_hello = "device_hello"
|
||||||
# ── v3 frame types ─────────────────────────────────────────────────
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
home_request = "home_request"
|
home_request = "home_request"
|
||||||
floating_request = "floating_request"
|
floating_request = "floating_request"
|
||||||
stream_start = "stream_start"
|
stream_start = "stream_start"
|
||||||
stream_text = "stream_text"
|
stream_text = "stream_text"
|
||||||
stream_block = "stream_block"
|
|
||||||
stream_end = "stream_end"
|
stream_end = "stream_end"
|
||||||
floating_domain = "floating_domain"
|
floating_domain = "floating_domain"
|
||||||
data_request = "data_request"
|
data_request = "data_request"
|
||||||
data_response = "data_response"
|
data_response = "data_response"
|
||||||
mutation = "mutation"
|
mutation = "mutation"
|
||||||
|
# ── v4 journey frame types ────────────────────────────────────────
|
||||||
|
journey_start = "journey_start"
|
||||||
|
journey_message = "journey_message"
|
||||||
|
journey_reply = "journey_reply"
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -237,31 +209,6 @@ class WsDeviceHello(BaseModel):
|
|||||||
agent_ids: list[str] = Field(default_factory=list)
|
agent_ids: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class WsAgentRun(BaseModel):
|
|
||||||
"""Server → Client: trigger an agent run on the connected device."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_run] = WsFrameType.agent_run
|
|
||||||
run_id: str
|
|
||||||
agent_id: str
|
|
||||||
config: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class WsAgentData(BaseModel):
|
|
||||||
"""Client → Server: files read by the local agent."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_data] = WsFrameType.agent_data
|
|
||||||
run_id: str
|
|
||||||
files: list[dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class WsAgentComplete(BaseModel):
|
|
||||||
"""Client → Server: Electron signals it has finished reading files."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_complete] = WsFrameType.agent_complete
|
|
||||||
run_id: str
|
|
||||||
files_read: int
|
|
||||||
errors: list[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
@@ -303,21 +250,19 @@ class WsStreamText(BaseModel):
|
|||||||
chunk: str
|
chunk: str
|
||||||
|
|
||||||
|
|
||||||
class WsStreamBlock(BaseModel):
|
|
||||||
"""Server → Client: structured block (chart, table, entity, timeline)."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.stream_block] = WsFrameType.stream_block
|
|
||||||
request_id: str
|
|
||||||
block_type: Literal["chart", "entity_ref", "table", "timeline"]
|
|
||||||
data: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class WsStreamEnd(BaseModel):
|
class WsStreamEnd(BaseModel):
|
||||||
"""Server → Client: signals end of a streaming response."""
|
"""Server → Client: signals end of a streaming response."""
|
||||||
|
|
||||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
request_id: str
|
request_id: str
|
||||||
mutations: list[dict[str, Any]] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
class WsDomain(BaseModel):
|
||||||
|
"""Structured floating domain payload for UI routing decisions."""
|
||||||
|
|
||||||
|
type: Literal["task", "timeline", "project", "node"]
|
||||||
|
id: str | None = None
|
||||||
|
section: Literal["task", "timeline", "note"] | None = None
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingDomain(BaseModel):
|
class WsFloatingDomain(BaseModel):
|
||||||
@@ -325,7 +270,7 @@ class WsFloatingDomain(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||||
request_id: str
|
request_id: str
|
||||||
domain: Literal["tasks", "timelines", "notes", "projects"]
|
domain: WsDomain
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
@@ -334,84 +279,28 @@ class AgentCatalogItem(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
config_schema: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local Agent Config ────────────────────────────────────────────────
|
class AgentCreationCheckRequest(BaseModel):
|
||||||
|
active_agents: int = Field(ge=0, default=0)
|
||||||
class LocalAgentConfigCreate(BaseModel):
|
|
||||||
name: str
|
|
||||||
device_id: str
|
|
||||||
directory_paths: list[str]
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
file_extensions: list[str]
|
|
||||||
schedule_cron: str
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfigUpdate(BaseModel):
|
class AgentCreationCheckResponse(BaseModel):
|
||||||
name: str | None = None
|
allowed: bool
|
||||||
device_id: str | None = None
|
tier: BillingTier
|
||||||
directory_paths: list[str] | None = None
|
active_agents: int
|
||||||
data_types: list[str] | None = None
|
limit: int
|
||||||
prompt_template: str | None = None
|
|
||||||
file_extensions: list[str] | None = None
|
|
||||||
schedule_cron: str | None = None
|
|
||||||
enabled: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfigResponse(BaseModel):
|
class AgentTriggerRequest(BaseModel):
|
||||||
id: str
|
directory: str = Field(min_length=1)
|
||||||
name: str
|
device_id: str = Field(default="")
|
||||||
device_id: str
|
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
||||||
directory_paths: list[str]
|
what_to_extract: list[str] = Field(min_length=1)
|
||||||
data_types: list[str]
|
actions_by_type: dict[str, list[str]] | None = None
|
||||||
prompt_template: str
|
batch_interval: str = Field(min_length=1)
|
||||||
file_extensions: list[str]
|
custom_agent_prompt: str = Field(min_length=1)
|
||||||
schedule_cron: str
|
active_agents: int = Field(ge=0, default=0)
|
||||||
enabled: bool
|
|
||||||
last_run_at: int | None
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud Agent Config ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class CloudAgentConfigCreate(BaseModel):
|
|
||||||
provider: Literal["gmail", "teams", "outlook"]
|
|
||||||
name: str
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
oauth_token_encrypted: str
|
|
||||||
schedule_cron: str
|
|
||||||
filter_config: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class CloudAgentConfigUpdate(BaseModel):
|
|
||||||
provider: Literal["gmail", "teams", "outlook"] | None = None
|
|
||||||
name: str | None = None
|
|
||||||
data_types: list[str] | None = None
|
|
||||||
prompt_template: str | None = None
|
|
||||||
oauth_token_encrypted: str | None = None
|
|
||||||
schedule_cron: str | None = None
|
|
||||||
filter_config: dict[str, Any] | None = None
|
|
||||||
enabled: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class CloudAgentConfigResponse(BaseModel):
|
|
||||||
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
provider: Literal["gmail", "teams", "outlook"]
|
|
||||||
name: str
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
schedule_cron: str
|
|
||||||
filter_config: dict[str, Any] | None
|
|
||||||
enabled: bool
|
|
||||||
last_run_at: int | None
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
@@ -430,18 +319,3 @@ class AgentRunLogResponse(BaseModel):
|
|||||||
|
|
||||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||||
|
|
||||||
class JourneyStartRequest(BaseModel):
|
|
||||||
agent_type: Literal["local", "cloud"]
|
|
||||||
agent_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class JourneyMessageRequest(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
class JourneyResponse(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
message: str
|
|
||||||
done: bool
|
|
||||||
prompt_template: str | None = None
|
|
||||||
|
|||||||
879
docs/MICROSERVICES_ARCHITECTURE.md
Normal file
879
docs/MICROSERVICES_ARCHITECTURE.md
Normal file
@@ -0,0 +1,879 @@
|
|||||||
|
# Adiuva — Architettura Microservizi
|
||||||
|
|
||||||
|
## Panoramica
|
||||||
|
|
||||||
|
Il monolite attuale viene suddiviso in **5 servizi** + un **API Gateway**, orchestrati con Docker Compose e raggiungibili tramite dominio su Cloudflare.
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────┐
|
||||||
|
│ Cloudflare │
|
||||||
|
│ (DNS + CDN) │
|
||||||
|
└──────┬───────┘
|
||||||
|
│ HTTPS / WSS
|
||||||
|
┌──────▼───────┐
|
||||||
|
│ Traefik │
|
||||||
|
│ API Gateway │
|
||||||
|
│ (routing, │
|
||||||
|
│ TLS term.) │
|
||||||
|
└──────┬───────┘
|
||||||
|
│
|
||||||
|
┌──────────┬───────────┼───────────┬──────────┐
|
||||||
|
│ │ │ │ │
|
||||||
|
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐ ┌───▼─────┐
|
||||||
|
│ Auth │ │ Chat │ │ Storage │ │Billing │ │ Plugins │
|
||||||
|
│ Service │ │Service│ │ Service │ │Service │ │ Service │
|
||||||
|
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘ └───┬─────┘
|
||||||
|
│ │ │ │ │
|
||||||
|
┌─────▼──────────▼──────────▼───────────▼──────────▼─────┐
|
||||||
|
│ Infrastruttura │
|
||||||
|
│ PostgreSQL │ Redis │ MinIO (S3) │ Qdrant │ (Pinecone) │
|
||||||
|
└────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Suddivisione dei Servizi
|
||||||
|
|
||||||
|
### 1.1 Auth Service (`auth-service`)
|
||||||
|
|
||||||
|
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
|
||||||
|
|
||||||
|
| Endpoint originale | Metodo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/auth/register` | POST |
|
||||||
|
| `/api/v1/auth/login` | POST |
|
||||||
|
| `/api/v1/auth/refresh` | POST |
|
||||||
|
| `/api/v1/auth/me` | GET / PUT |
|
||||||
|
|
||||||
|
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
|
||||||
|
|
||||||
|
**Modifica chiave — JWT con RS256**:
|
||||||
|
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
|
||||||
|
- L'Auth Service firma i JWT con la **chiave privata**.
|
||||||
|
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
|
||||||
|
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# auth-service/app/auth/jwt.py
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
PRIVATE_KEY = ... # Da env/secret
|
||||||
|
PUBLIC_KEY = ... # Derivata o da env
|
||||||
|
|
||||||
|
def create_access_token(user_id: str, tier: str) -> str:
|
||||||
|
return jwt.encode(
|
||||||
|
{"sub": user_id, "tier": tier, "exp": ...},
|
||||||
|
PRIVATE_KEY,
|
||||||
|
algorithm="RS256",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# shared/auth.py (usato da tutti gli altri servizi)
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
|
||||||
|
|
||||||
|
def verify_token(token: str) -> dict:
|
||||||
|
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
||||||
|
```
|
||||||
|
|
||||||
|
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.2 Chat Service (`chat-service`) ⭐ Core
|
||||||
|
|
||||||
|
**Responsabilità**: WebSocket device, home chat, floating chat, agent runner, memory middleware, agent setup journeys.
|
||||||
|
|
||||||
|
| Endpoint originale | Tipo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/ws/device` | WebSocket |
|
||||||
|
| `/api/v1/chat` | POST (REST fallback) |
|
||||||
|
| `/api/v1/agents/catalog` | GET |
|
||||||
|
| `/api/v1/agents/can-create` | POST |
|
||||||
|
| `/api/v1/agents/trigger` | POST |
|
||||||
|
|
||||||
|
**Moduli inclusi**: `deep_agent`, `agent_runner`, `agent_registry`, `memory_middleware`, `ws_context`, `device_manager`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`, `filesystem_agent`).
|
||||||
|
|
||||||
|
**Questa è la bestia che deve scalare orizzontalmente** — è il servizio più CPU/memory intensive (LLM calls, tool loops, WebSocket persistenti).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.3 Storage Service (`storage-service`)
|
||||||
|
|
||||||
|
**Responsabilità**: CRUD record crittografati su S3, vector operations, backup.
|
||||||
|
|
||||||
|
| Endpoint originale | Metodo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/storage/records` | POST / GET |
|
||||||
|
| `/api/v1/storage/records/{id}` | GET / PUT / DELETE |
|
||||||
|
| `/api/v1/vectors/upsert` | POST |
|
||||||
|
| `/api/v1/vectors/search` | POST |
|
||||||
|
| `/api/v1/vectors/embed` | POST |
|
||||||
|
| `/api/v1/vectors` | DELETE |
|
||||||
|
| `/api/v1/backup` | PUT / GET / DELETE |
|
||||||
|
| `/api/v1/backup/history` | GET |
|
||||||
|
|
||||||
|
**Scaling**: 2–3 repliche. I/O bound (S3, Qdrant). Stateless.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.4 Billing Service (`billing-service`)
|
||||||
|
|
||||||
|
**Responsabilità**: Stripe checkout, webhook, subscription management, tier enforcement.
|
||||||
|
|
||||||
|
| Endpoint originale | Metodo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/billing/checkout` | POST |
|
||||||
|
| `/api/v1/billing/webhook` | POST |
|
||||||
|
| `/api/v1/billing/subscription` | GET / DELETE |
|
||||||
|
|
||||||
|
**Database**: Tabelle `subscriptions` (schema `billing`).
|
||||||
|
|
||||||
|
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users (oppure i servizi leggono il tier direttamente dal JWT, aggiornato al prossimo refresh).
|
||||||
|
|
||||||
|
**Scaling**: 1 replica sufficiente. Basso traffico.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.5 Plugin Service (`plugin-service`)
|
||||||
|
|
||||||
|
**Responsabilità**: Marketplace, installazione plugin, revenue split.
|
||||||
|
|
||||||
|
| Endpoint originale | Metodo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/plugins` | GET |
|
||||||
|
| `/api/v1/plugins/{id}` | GET |
|
||||||
|
| `/api/v1/plugins/{id}/install` | POST / DELETE |
|
||||||
|
|
||||||
|
**Database**: Tabelle `plugins`, `plugin_installations`, `revenue_events`.
|
||||||
|
|
||||||
|
**Scaling**: 1 replica. Basso traffico.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. WebSocket con Scaling Orizzontale — Il Problema Chiave
|
||||||
|
|
||||||
|
### Il problema attuale
|
||||||
|
|
||||||
|
`DeviceConnectionManager` è un **singleton in-memory**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class DeviceConnectionManager:
|
||||||
|
def __init__(self):
|
||||||
|
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
|
||||||
|
```
|
||||||
|
|
||||||
|
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
|
||||||
|
|
||||||
|
### La soluzione: Redis Pub/Sub + Registry
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ Redis │
|
||||||
|
│ │
|
||||||
|
│ Hash: ws:connections │
|
||||||
|
│ user_123 → instance_A │
|
||||||
|
│ user_456 → instance_B │
|
||||||
|
│ │
|
||||||
|
│ Pub/Sub channels: │
|
||||||
|
│ tool_call:{user_id} → tool call payloads │
|
||||||
|
│ tool_result:{call_id} → tool result payloads │
|
||||||
|
│ stream:{user_id} → text_chunk streaming │
|
||||||
|
└──────────────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
|
||||||
|
┌───────────────────────┐ ┌───────────────────────┐
|
||||||
|
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
|
||||||
|
│ tool_call:user_123│ │ → user_123 è su A │
|
||||||
|
│ │ │ │
|
||||||
|
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
|
||||||
|
│ da Redis channel │ │ tool_call:user_123 │
|
||||||
|
│ │ │ {id, action, ...} │
|
||||||
|
│ 3. Invia al device │ │ │
|
||||||
|
│ via WS │ │ 4. SUBSCRIBE │
|
||||||
|
│ │ │ tool_result:{id} │
|
||||||
|
│ 4. Device risponde │ │ │
|
||||||
|
│ tool_result │──────────│► 5. Riceve risultato │
|
||||||
|
│ │ │ │
|
||||||
|
│ 5. PUBLISH │ │ │
|
||||||
|
│ tool_result:{id} │ │ │
|
||||||
|
└───────────────────────┘ └───────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Implementazione: `RedisDeviceManager`
|
||||||
|
|
||||||
|
```python
|
||||||
|
# chat-service/app/core/device_manager.py
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LocalConnection:
|
||||||
|
ws: WebSocket
|
||||||
|
device_id: str
|
||||||
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisDeviceManager:
|
||||||
|
"""Device manager backed by Redis for cross-instance communication."""
|
||||||
|
|
||||||
|
def __init__(self, redis_url: str = "redis://redis:6379"):
|
||||||
|
self._redis = aioredis.from_url(redis_url)
|
||||||
|
self._pubsub = self._redis.pubsub()
|
||||||
|
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
|
||||||
|
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Avvia il listener Redis per tool_call in arrivo."""
|
||||||
|
asyncio.create_task(self._listen_tool_calls())
|
||||||
|
|
||||||
|
# ── Registrazione ──
|
||||||
|
|
||||||
|
async def register(self, user_id: str, device_id: str, ws: WebSocket):
|
||||||
|
# Registra localmente
|
||||||
|
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
|
||||||
|
# Registra in Redis quale istanza ha la connessione
|
||||||
|
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
|
||||||
|
# Sottoscrivi ai tool_call per questo utente
|
||||||
|
await self._pubsub.subscribe(f"tool_call:{user_id}")
|
||||||
|
|
||||||
|
async def unregister(self, user_id: str):
|
||||||
|
conn = self._local.pop(user_id, None)
|
||||||
|
if conn:
|
||||||
|
for fut in conn.pending_calls.values():
|
||||||
|
if not fut.done():
|
||||||
|
fut.cancel()
|
||||||
|
await self._redis.hdel("ws:connections", user_id)
|
||||||
|
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
|
||||||
|
|
||||||
|
# ── Presenza ──
|
||||||
|
|
||||||
|
async def is_online(self, user_id: str) -> bool:
|
||||||
|
return await self._redis.hexists("ws:connections", user_id)
|
||||||
|
|
||||||
|
# ── Tool-call round-trip (cross-instance) ──
|
||||||
|
|
||||||
|
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Invia un tool_call al device dell'utente.
|
||||||
|
Funziona sia che la WS sia locale che su un'altra istanza.
|
||||||
|
"""
|
||||||
|
call_id = payload["id"]
|
||||||
|
|
||||||
|
# Caso 1: connessione locale → invio diretto
|
||||||
|
if user_id in self._local:
|
||||||
|
conn = self._local[user_id]
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
fut: asyncio.Future[dict] = loop.create_future()
|
||||||
|
conn.pending_calls[call_id] = fut
|
||||||
|
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
|
||||||
|
return await asyncio.wait_for(fut, timeout=30.0)
|
||||||
|
|
||||||
|
# Caso 2: connessione remota → Redis pub/sub
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
fut = loop.create_future()
|
||||||
|
self._remote_futures[call_id] = fut
|
||||||
|
|
||||||
|
# Sottoscrivi al canale di risposta
|
||||||
|
result_channel = f"tool_result:{call_id}"
|
||||||
|
await self._pubsub.subscribe(result_channel)
|
||||||
|
|
||||||
|
# Pubblica il tool_call
|
||||||
|
await self._redis.publish(
|
||||||
|
f"tool_call:{user_id}",
|
||||||
|
json.dumps(payload),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(fut, timeout=30.0)
|
||||||
|
finally:
|
||||||
|
self._remote_futures.pop(call_id, None)
|
||||||
|
await self._pubsub.unsubscribe(result_channel)
|
||||||
|
|
||||||
|
# ── Risoluzione tool_result (da WS locale) ──
|
||||||
|
|
||||||
|
def resolve_local(self, user_id: str, call_id: str, result: dict):
|
||||||
|
conn = self._local.get(user_id)
|
||||||
|
if conn:
|
||||||
|
fut = conn.pending_calls.pop(call_id, None)
|
||||||
|
if fut and not fut.done():
|
||||||
|
fut.set_result(result)
|
||||||
|
|
||||||
|
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
|
||||||
|
"""Chiamato quando il device locale invia un tool_result."""
|
||||||
|
self.resolve_local(user_id, call_id, result)
|
||||||
|
# Pubblica anche su Redis per l'istanza remota che aspetta
|
||||||
|
await self._redis.publish(
|
||||||
|
f"tool_result:{call_id}",
|
||||||
|
json.dumps(result),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Listener Redis ──
|
||||||
|
|
||||||
|
async def _listen_tool_calls(self):
|
||||||
|
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
|
||||||
|
async for message in self._pubsub.listen():
|
||||||
|
if message["type"] != "message":
|
||||||
|
continue
|
||||||
|
channel = message["channel"]
|
||||||
|
if isinstance(channel, bytes):
|
||||||
|
channel = channel.decode()
|
||||||
|
|
||||||
|
data = json.loads(message["data"])
|
||||||
|
|
||||||
|
if channel.startswith("tool_call:"):
|
||||||
|
# Un'altra istanza vuole che inviamo un tool_call al nostro device
|
||||||
|
user_id = channel.split(":", 1)[1]
|
||||||
|
conn = self._local.get(user_id)
|
||||||
|
if conn:
|
||||||
|
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
|
||||||
|
|
||||||
|
elif channel.startswith("tool_result:"):
|
||||||
|
# Risposta a un tool_call che abbiamo inviato tramite Redis
|
||||||
|
call_id = channel.split(":", 1)[1]
|
||||||
|
fut = self._remote_futures.pop(call_id, None)
|
||||||
|
if fut and not fut.done():
|
||||||
|
fut.set_result(data)
|
||||||
|
|
||||||
|
# ── Stream cross-instance ──
|
||||||
|
|
||||||
|
async def publish_stream_chunk(self, user_id: str, chunk: dict):
|
||||||
|
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
|
||||||
|
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Struttura Directory Proposta
|
||||||
|
|
||||||
|
```
|
||||||
|
adiuva-api/
|
||||||
|
├── docker-compose.yml # Orchestrazione completa
|
||||||
|
├── docker-compose.dev.yml # Override per sviluppo locale
|
||||||
|
├── shared/ # Codice condiviso (montato come volume)
|
||||||
|
│ ├── auth.py # JWT verification (chiave pubblica)
|
||||||
|
│ ├── schemas.py # Pydantic schemas condivisi
|
||||||
|
│ ├── middleware/
|
||||||
|
│ │ ├── rate_limit.py
|
||||||
|
│ │ └── sanitizer.py
|
||||||
|
│ └── models/
|
||||||
|
│ └── base.py # SQLAlchemy base condivisa
|
||||||
|
│
|
||||||
|
├── auth-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # users, refresh_tokens
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ └── auth.py
|
||||||
|
│ └── services/
|
||||||
|
│ ├── jwt_service.py # RS256 signing
|
||||||
|
│ └── user_service.py
|
||||||
|
│
|
||||||
|
├── chat-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # agent_run_logs, memory_*
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ ├── device_ws.py
|
||||||
|
│ │ ├── chat.py
|
||||||
|
│ │ └── agents.py
|
||||||
|
│ ├── core/
|
||||||
|
│ │ ├── device_manager.py # RedisDeviceManager
|
||||||
|
│ │ ├── deep_agent.py
|
||||||
|
│ │ ├── agent_runner.py
|
||||||
|
│ │ ├── agent_registry.py
|
||||||
|
│ │ ├── memory_middleware.py
|
||||||
|
│ │ ├── ws_context.py
|
||||||
|
│ │ ├── output_formatter.py
|
||||||
|
│ │ └── llm.py
|
||||||
|
│ └── agents/
|
||||||
|
│ ├── task_agent.py
|
||||||
|
│ ├── project_agent.py
|
||||||
|
│ ├── note_agent.py
|
||||||
|
│ ├── timeline_agent.py
|
||||||
|
│ └── filesystem_agent.py
|
||||||
|
│
|
||||||
|
├── storage-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # storage_records, backup_metadata
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ ├── storage.py
|
||||||
|
│ │ ├── vectors.py
|
||||||
|
│ │ └── backup.py
|
||||||
|
│ └── services/
|
||||||
|
│ ├── blob_store.py
|
||||||
|
│ └── vector_store.py
|
||||||
|
│
|
||||||
|
├── billing-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # subscriptions
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ └── billing.py
|
||||||
|
│ └── services/
|
||||||
|
│ ├── stripe_service.py
|
||||||
|
│ └── tier_manager.py
|
||||||
|
│
|
||||||
|
├── plugin-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # plugins, installations, revenue
|
||||||
|
│ └── routes/
|
||||||
|
│ └── plugins.py
|
||||||
|
│
|
||||||
|
└── infra/
|
||||||
|
├── traefik/
|
||||||
|
│ └── traefik.yml
|
||||||
|
└── alembic/ # Migrazioni condivise o per-servizio
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Docker Compose — Configurazione Completa
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# docker-compose.yml
|
||||||
|
|
||||||
|
services:
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# API Gateway
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
traefik:
|
||||||
|
image: traefik:v3.2
|
||||||
|
command:
|
||||||
|
- "--api.insecure=true"
|
||||||
|
- "--providers.docker=true"
|
||||||
|
- "--providers.docker.exposedbydefault=false"
|
||||||
|
- "--entrypoints.web.address=:80"
|
||||||
|
- "--entrypoints.websecure.address=:443"
|
||||||
|
# Cloudflare gestisce TLS, Traefik riceve HTTP dal proxy
|
||||||
|
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
|
||||||
|
ports:
|
||||||
|
- "80:80"
|
||||||
|
- "443:443"
|
||||||
|
- "8080:8080" # Dashboard Traefik
|
||||||
|
volumes:
|
||||||
|
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Auth Service (2 repliche)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
auth-service:
|
||||||
|
build: ./auth-service
|
||||||
|
deploy:
|
||||||
|
replicas: 2
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
|
||||||
|
SERVICE_NAME: auth
|
||||||
|
secrets:
|
||||||
|
- jwt_private_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
|
||||||
|
- "traefik.http.services.auth.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Chat Service (scalabile, N repliche)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
chat-service:
|
||||||
|
build: ./chat-service
|
||||||
|
deploy:
|
||||||
|
replicas: 3
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: chat
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
# REST routes
|
||||||
|
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`) || PathPrefix(`/api/v1/agents`)"
|
||||||
|
- "traefik.http.services.chat.loadbalancer.server.port=8000"
|
||||||
|
# WebSocket route con sticky session
|
||||||
|
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
|
||||||
|
- "traefik.http.routers.ws.service=chat-ws"
|
||||||
|
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
|
||||||
|
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
|
||||||
|
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Storage Service (2 repliche)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
storage-service:
|
||||||
|
build: ./storage-service
|
||||||
|
deploy:
|
||||||
|
replicas: 2
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: storage
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.storage.rule=PathPrefix(`/api/v1/storage`) || PathPrefix(`/api/v1/vectors`) || PathPrefix(`/api/v1/backup`)"
|
||||||
|
- "traefik.http.services.storage.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Billing Service (1 replica)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
billing-service:
|
||||||
|
build: ./billing-service
|
||||||
|
deploy:
|
||||||
|
replicas: 1
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: billing
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
|
||||||
|
- "traefik.http.services.billing.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Plugin Service (1 replica)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
plugin-service:
|
||||||
|
build: ./plugin-service
|
||||||
|
deploy:
|
||||||
|
replicas: 1
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: plugins
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.plugins.rule=PathPrefix(`/api/v1/plugins`)"
|
||||||
|
- "traefik.http.services.plugins.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Infrastruttura
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
db:
|
||||||
|
image: pgvector/pgvector:pg16
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
POSTGRES_DB: adiuva
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:7-alpine
|
||||||
|
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
||||||
|
volumes:
|
||||||
|
- redis_data:/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 3s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
minio:
|
||||||
|
image: minio/minio:latest
|
||||||
|
command: server /data --console-address ":9001"
|
||||||
|
ports:
|
||||||
|
- "9000:9000"
|
||||||
|
- "9001:9001"
|
||||||
|
environment:
|
||||||
|
MINIO_ROOT_USER: minioadmin
|
||||||
|
MINIO_ROOT_PASSWORD: minioadmin
|
||||||
|
volumes:
|
||||||
|
- minio_data:/data
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
qdrant:
|
||||||
|
image: qdrant/qdrant:latest
|
||||||
|
volumes:
|
||||||
|
- qdrant_data:/qdrant/storage
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
secrets:
|
||||||
|
jwt_private_key:
|
||||||
|
file: ./infra/keys/jwt_private.pem
|
||||||
|
jwt_public_key:
|
||||||
|
file: ./infra/keys/jwt_public.pem
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
redis_data:
|
||||||
|
minio_data:
|
||||||
|
qdrant_data:
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Configurazione Cloudflare + VPS
|
||||||
|
|
||||||
|
### 5.1 DNS
|
||||||
|
|
||||||
|
```
|
||||||
|
api.tuodominio.com → A record → IP del VPS
|
||||||
|
→ Proxy: ON (orange cloud)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 Cloudflare Settings
|
||||||
|
|
||||||
|
| Setting | Valore | Motivo |
|
||||||
|
|---------|--------|--------|
|
||||||
|
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
|
||||||
|
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
|
||||||
|
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
|
||||||
|
| Under Attack Mode | Off (attivare se necessario) | |
|
||||||
|
|
||||||
|
### 5.3 TLS sul VPS
|
||||||
|
|
||||||
|
Due opzioni:
|
||||||
|
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
|
||||||
|
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# traefik.yml — con Cloudflare Origin Certificate
|
||||||
|
entryPoints:
|
||||||
|
websecure:
|
||||||
|
address: ":443"
|
||||||
|
|
||||||
|
tls:
|
||||||
|
certificates:
|
||||||
|
- certFile: /certs/origin.pem
|
||||||
|
keyFile: /certs/origin-key.pem
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.4 Rete VPS
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
|
||||||
|
# https://www.cloudflare.com/ips/
|
||||||
|
ufw default deny incoming
|
||||||
|
ufw allow from 173.245.48.0/20 to any port 443
|
||||||
|
ufw allow from 103.21.244.0/22 to any port 443
|
||||||
|
# ... (tutti gli IP range di Cloudflare)
|
||||||
|
ufw allow ssh
|
||||||
|
ufw enable
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Comunicazione Inter-Servizio
|
||||||
|
|
||||||
|
### 6.1 Pattern: Event Bus via Redis Pub/Sub
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────┐ tier_changed:user_123 ┌──────────┐
|
||||||
|
│ Billing │ ────────────────────────► │ Auth │
|
||||||
|
│ Service │ │ Service │
|
||||||
|
└──────────┘ └──────────┘
|
||||||
|
|
||||||
|
┌──────────┐ agent_triggered:user_123 ┌──────────┐
|
||||||
|
│ Chat │ ◄──────────────────────── │ Any │
|
||||||
|
│ Service │ │ Service │
|
||||||
|
└──────────┘ └──────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 Pattern: HTTP Sincrono (per query semplici)
|
||||||
|
|
||||||
|
Il Chat Service può avere bisogno del tier dell'utente per il rate-limiting degli agent. Due strategie:
|
||||||
|
|
||||||
|
- **Strategia A (preferita)**: Il tier è nel JWT. All'aggiornamento, il Billing Service forza token refresh invalidando i vecchi token su Redis.
|
||||||
|
- **Strategia B**: Il Chat Service chiama `http://auth-service:8000/internal/user/{id}/tier` (rete Docker interna, non esposta).
|
||||||
|
|
||||||
|
### 6.3 Health Checks e Service Discovery
|
||||||
|
|
||||||
|
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
|
||||||
|
- **Redis pub/sub** (eventi asincroni)
|
||||||
|
- **Redis hash** (stato condiviso, es. `ws:connections`)
|
||||||
|
- **PostgreSQL** (dati persistenti condivisi)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Piano di Migrazione Incrementale
|
||||||
|
|
||||||
|
### Fase 1 — Preparazione (senza rompere nulla)
|
||||||
|
1. Aggiungere Redis al `docker-compose.yml` attuale
|
||||||
|
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi)
|
||||||
|
3. Implementare `RedisDeviceManager` come drop-in replacement
|
||||||
|
4. Estrarre `shared/` con auth verification, schemas, middleware
|
||||||
|
|
||||||
|
### Fase 2 — Primo split: Auth Service
|
||||||
|
1. Estrarre `auth.py` routes + models in `auth-service/`
|
||||||
|
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
|
||||||
|
3. Aggiornare Traefik per routare `/api/v1/auth/*` al nuovo servizio
|
||||||
|
4. Il monolite continua a servire tutto il resto
|
||||||
|
|
||||||
|
### Fase 3 — Storage + Billing + Plugins
|
||||||
|
1. Servizi stateless e senza WebSocket → facili da estrarre
|
||||||
|
2. Estrarre uno alla volta, testare, routare via Traefik
|
||||||
|
3. Il monolite diventa sempre più magro
|
||||||
|
|
||||||
|
### Fase 4 — Chat Service (il più delicato)
|
||||||
|
1. Il monolite residuo **diventa** il Chat Service
|
||||||
|
2. Rimuovere i route migrati, tenere solo WS + chat + agents
|
||||||
|
3. Testare lo scaling a 2+ istanze con `RedisDeviceManager`
|
||||||
|
4. Verificare tool-call cross-instance
|
||||||
|
|
||||||
|
### Fase 5 — Cleanup
|
||||||
|
1. Rimuovere il monolite originale
|
||||||
|
2. CI/CD pipeline per build/push separati
|
||||||
|
3. Monitoring (Prometheus + Grafana) per ogni servizio
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Rate Limiting Distribuito
|
||||||
|
|
||||||
|
Il middleware attuale usa un contatore in-memory per il rate-limiting. Con i microservizi:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# shared/middleware/rate_limit.py
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
class DistributedRateLimiter:
|
||||||
|
def __init__(self, redis: aioredis.Redis):
|
||||||
|
self._redis = redis
|
||||||
|
|
||||||
|
async def check(self, user_id: str, tier: str) -> bool:
|
||||||
|
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
|
||||||
|
max_req = limits.get(tier, 20)
|
||||||
|
key = f"rate:{user_id}"
|
||||||
|
|
||||||
|
pipe = self._redis.pipeline()
|
||||||
|
pipe.incr(key)
|
||||||
|
pipe.expire(key, 60) # Finestra di 60 secondi
|
||||||
|
count, _ = await pipe.execute()
|
||||||
|
|
||||||
|
return count <= max_req
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Monitoraggio e Logging
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Aggiungere al docker-compose.yml
|
||||||
|
|
||||||
|
prometheus:
|
||||||
|
image: prom/prometheus:latest
|
||||||
|
volumes:
|
||||||
|
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
grafana:
|
||||||
|
image: grafana/grafana:latest
|
||||||
|
ports:
|
||||||
|
- "3000:3000"
|
||||||
|
volumes:
|
||||||
|
- grafana_data:/var/lib/grafana
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
loki:
|
||||||
|
image: grafana/loki:latest
|
||||||
|
restart: unless-stopped
|
||||||
|
```
|
||||||
|
|
||||||
|
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. Sizing VPS Minimo Consigliato
|
||||||
|
|
||||||
|
| Componente | CPU | RAM | Note |
|
||||||
|
|---|---|---|---|
|
||||||
|
| Traefik | 0.25 | 128MB | |
|
||||||
|
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | |
|
||||||
|
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | Il più pesante (LLM calls) |
|
||||||
|
| Storage Service ×2 | 0.5 ×2 | 256MB ×2 | I/O bound |
|
||||||
|
| Billing Service | 0.25 | 128MB | |
|
||||||
|
| Plugin Service | 0.25 | 128MB | |
|
||||||
|
| PostgreSQL | 1.0 | 1GB | |
|
||||||
|
| Redis | 0.25 | 256MB | |
|
||||||
|
| Qdrant | 0.5 | 512MB | |
|
||||||
|
| MinIO | 0.25 | 256MB | |
|
||||||
|
| **Totale** | **~6 vCPU** | **~5.5 GB** | |
|
||||||
|
|
||||||
|
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Riepilogo Decisioni Architetturali
|
||||||
|
|
||||||
|
| Decisione | Scelta | Motivazione |
|
||||||
|
|---|---|---|
|
||||||
|
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
|
||||||
|
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
|
||||||
|
| WebSocket scaling | Redis pub/sub + registry | Cross-instance tool-call routing |
|
||||||
|
| Rate limiting | Redis contatori | Distribuito, sliding window |
|
||||||
|
| Service communication | Redis pub/sub + HTTP interno | Asincrono per eventi, sincrono per query |
|
||||||
|
| Database | PostgreSQL condiviso (un DB, schema separation opzionale) | Semplicità; split DB futuro facile |
|
||||||
|
| TLS | Cloudflare Origin Certificate | Zero maintenance, trust Cloudflare |
|
||||||
|
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
|
||||||
@@ -1,214 +0,0 @@
|
|||||||
"""Unit tests for the agent registry, base classes, and tool loop."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _StubAgent(ChatAgent):
|
|
||||||
"""Minimal concrete agent for testing."""
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "stub"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "A stub agent for tests"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return f"echo: {query}"
|
|
||||||
|
|
||||||
|
|
||||||
class _AnotherAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "another"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Another stub"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return "another"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _fresh_registry():
|
|
||||||
"""Reset the singleton between tests."""
|
|
||||||
AgentRegistry._instance = None
|
|
||||||
yield
|
|
||||||
AgentRegistry._instance = None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def reg() -> AgentRegistry:
|
|
||||||
return AgentRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tests ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestRegisterAndGet:
|
|
||||||
def test_register_decorator(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
agent = reg.get("stub")
|
|
||||||
assert isinstance(agent, _StubAgent)
|
|
||||||
|
|
||||||
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
|
|
||||||
with pytest.raises(KeyError, match="not found"):
|
|
||||||
reg.get("nonexistent")
|
|
||||||
|
|
||||||
def test_register_multiple(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
reg.register(_AnotherAgent)
|
|
||||||
assert reg.get("stub").get_name() == "stub"
|
|
||||||
assert reg.get("another").get_name() == "another"
|
|
||||||
|
|
||||||
|
|
||||||
class TestListAgents:
|
|
||||||
def test_empty(self, reg: AgentRegistry) -> None:
|
|
||||||
assert reg.list_agents() == []
|
|
||||||
|
|
||||||
def test_list_after_register(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
agents = reg.list_agents()
|
|
||||||
assert len(agents) == 1
|
|
||||||
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
|
|
||||||
|
|
||||||
def test_list_multiple(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
reg.register(_AnotherAgent)
|
|
||||||
names = {a["name"] for a in reg.list_agents()}
|
|
||||||
assert names == {"stub", "another"}
|
|
||||||
|
|
||||||
|
|
||||||
class TestCallAgent:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_agent(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
result = await reg.call_agent("stub", "hello", {})
|
|
||||||
assert result == "echo: hello"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
await reg.call_agent("nope", "hi", {})
|
|
||||||
|
|
||||||
|
|
||||||
class TestSingleton:
|
|
||||||
def test_singleton_identity(self) -> None:
|
|
||||||
a = AgentRegistry()
|
|
||||||
b = AgentRegistry()
|
|
||||||
assert a is b
|
|
||||||
|
|
||||||
|
|
||||||
class TestToolLoop:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_tool_calls(self) -> None:
|
|
||||||
"""When the LLM responds without tool calls, return content directly."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
ai_msg = MagicMock()
|
|
||||||
ai_msg.content = "final answer"
|
|
||||||
ai_msg.tool_calls = []
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=ai_msg)
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [])
|
|
||||||
assert result == "final answer"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_call_then_answer(self) -> None:
|
|
||||||
"""LLM requests one tool call, gets result, then answers."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
# First response: tool call
|
|
||||||
tool_call_msg = MagicMock()
|
|
||||||
tool_call_msg.content = ""
|
|
||||||
tool_call_msg.tool_calls = [
|
|
||||||
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
|
|
||||||
]
|
|
||||||
|
|
||||||
# Second response: final answer
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = "done"
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm)
|
|
||||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
|
||||||
|
|
||||||
# Mock tool
|
|
||||||
tool = AsyncMock()
|
|
||||||
tool.name = "my_tool"
|
|
||||||
tool.ainvoke = AsyncMock(return_value="tool_result")
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [tool])
|
|
||||||
assert result == "done"
|
|
||||||
tool.ainvoke.assert_called_once_with({"x": 1})
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unknown_tool_handled(self) -> None:
|
|
||||||
"""Unknown tool names produce an error message instead of crashing."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
tool_call_msg = MagicMock()
|
|
||||||
tool_call_msg.content = ""
|
|
||||||
tool_call_msg.tool_calls = [
|
|
||||||
{"id": "call_1", "name": "missing", "args": {}}
|
|
||||||
]
|
|
||||||
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = "recovered"
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm)
|
|
||||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [])
|
|
||||||
assert result == "recovered"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_max_iter_reached(self) -> None:
|
|
||||||
"""When max iterations are exhausted, a final no-tools call is made."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
# Every response requests a tool call
|
|
||||||
loop_msg = MagicMock()
|
|
||||||
loop_msg.content = ""
|
|
||||||
loop_msg.tool_calls = [
|
|
||||||
{"id": "call_x", "name": "t", "args": {}}
|
|
||||||
]
|
|
||||||
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = "gave up"
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
tool = AsyncMock()
|
|
||||||
tool.name = "t"
|
|
||||||
tool.ainvoke = AsyncMock(return_value="ok")
|
|
||||||
|
|
||||||
llm_with_tools = AsyncMock()
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
|
|
||||||
assert result == "gave up"
|
|
||||||
assert llm_with_tools.ainvoke.call_count == 2
|
|
||||||
@@ -10,13 +10,13 @@ Coverage:
|
|||||||
- run_local_agent — file-read timeout path
|
- run_local_agent — file-read timeout path
|
||||||
- run_local_agent — LLM extraction error path
|
- run_local_agent — LLM extraction error path
|
||||||
- run_cloud_agent — stub returns error immediately
|
- run_cloud_agent — stub returns error immediately
|
||||||
- trigger_pending_runs — overdue local + cloud dispatched
|
- trigger_pending_runs — skipped when config is client-owned
|
||||||
- trigger_pending_runs — non-overdue skipped
|
- trigger_pending_runs — non-overdue skipped
|
||||||
- trigger_pending_runs — device_id filter for local agents
|
- trigger_pending_runs — device_id filter for local agents
|
||||||
|
|
||||||
Integration:
|
Integration:
|
||||||
- POST /agents/{id}/run — 404 on unknown agent
|
- POST /agents/can-create — billing eligibility check
|
||||||
- POST /agents/{id}/run — creates run log + dispatches background task
|
- POST /agents/trigger — creates run log + dispatches background task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -373,7 +373,7 @@ async def test_run_local_agent_happy_path():
|
|||||||
assert kwargs["items_processed"] == 1
|
assert kwargs["items_processed"] == 1
|
||||||
assert kwargs["items_created"] == 1
|
assert kwargs["items_created"] == 1
|
||||||
assert kwargs["errors"] == []
|
assert kwargs["errors"] == []
|
||||||
assert kwargs["update_config_last_run"] is True
|
assert kwargs["update_config_last_run"] is False
|
||||||
|
|
||||||
# Verify agent_run frame was sent.
|
# Verify agent_run frame was sent.
|
||||||
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
||||||
@@ -690,31 +690,11 @@ async def test_finalize_run_updates_cloud_config_last_run_at():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_no_overdue():
|
async def test_trigger_pending_runs_no_overdue():
|
||||||
"""If no agents are overdue trigger_pending_runs does nothing."""
|
"""Pending-run scan is skipped because agent config is client-owned."""
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
config = _make_local_config()
|
|
||||||
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
|
|
||||||
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
|
|
||||||
|
|
||||||
mock_db_result_local = MagicMock()
|
|
||||||
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
|
||||||
|
|
||||||
mock_db_result_cloud = MagicMock()
|
|
||||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
|
||||||
|
|
||||||
mgr = _make_manager()
|
mgr = _make_manager()
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
|
||||||
mock_ctx = AsyncMock()
|
|
||||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
|
||||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_ctx.execute = AsyncMock(
|
|
||||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
|
||||||
)
|
|
||||||
mock_session_factory.return_value = mock_ctx
|
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
mock_run.assert_not_called()
|
||||||
@@ -722,31 +702,11 @@ async def test_trigger_pending_runs_no_overdue():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_device_id_filter():
|
async def test_trigger_pending_runs_device_id_filter():
|
||||||
"""Local agents are only triggered for the matching device_id."""
|
"""Device filtering is no longer backend-managed in pending runs."""
|
||||||
# The DB query already filters by device_id, so we verify the SELECT
|
|
||||||
# includes the device_id filter by checking that a config bound to a
|
|
||||||
# different device is never dispatched.
|
|
||||||
#
|
|
||||||
# Since trigger_pending_runs queries with device_id == "dev-001",
|
|
||||||
# simulate the DB returning an empty list (as it would for a mismatch).
|
|
||||||
mock_db_result_local = MagicMock()
|
|
||||||
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
|
|
||||||
|
|
||||||
mock_db_result_cloud = MagicMock()
|
|
||||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
|
||||||
|
|
||||||
mgr = _make_manager(device_id="dev-001")
|
mgr = _make_manager(device_id="dev-001")
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
|
||||||
mock_ctx = AsyncMock()
|
|
||||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
|
||||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_ctx.execute = AsyncMock(
|
|
||||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
|
||||||
)
|
|
||||||
mock_session_factory.return_value = mock_ctx
|
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
mock_run.assert_not_called()
|
||||||
@@ -754,56 +714,18 @@ async def test_trigger_pending_runs_device_id_filter():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_dispatches_overdue():
|
async def test_trigger_pending_runs_dispatches_overdue():
|
||||||
"""Overdue local agent triggers run_local_agent sequentially."""
|
"""No pending runs are dispatched by backend after config deprecation."""
|
||||||
config = _make_local_config() # last_run_at=None → always overdue
|
|
||||||
|
|
||||||
mock_db_result_local = MagicMock()
|
|
||||||
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
|
||||||
|
|
||||||
mock_db_result_cloud = MagicMock()
|
|
||||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
|
||||||
|
|
||||||
mgr = _make_manager()
|
mgr = _make_manager()
|
||||||
|
|
||||||
call_order: list[str] = []
|
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
|
||||||
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
|
|
||||||
call_order.append("run_local")
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
|
||||||
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
|
|
||||||
# First call: query configs. Subsequent calls: create run_log.
|
|
||||||
mock_query_ctx = AsyncMock()
|
|
||||||
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
|
|
||||||
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_query_ctx.execute = AsyncMock(
|
|
||||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
|
||||||
)
|
|
||||||
|
|
||||||
run_log_obj = AgentRunLog(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
agent_id=config.id,
|
|
||||||
agent_type="local",
|
|
||||||
user_id=_FREE_UID,
|
|
||||||
status="running",
|
|
||||||
started_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
mock_insert_ctx = AsyncMock()
|
|
||||||
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
|
|
||||||
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_insert_ctx.add = MagicMock()
|
|
||||||
mock_insert_ctx.commit = AsyncMock()
|
|
||||||
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
|
|
||||||
|
|
||||||
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
|
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
assert call_order == ["run_local"]
|
mock_run.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Integration: POST /agents/{id}/run
|
# Integration: POST /agents/can-create and /agents/trigger
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -820,50 +742,67 @@ def _override_db(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_run_unknown_agent(client):
|
async def test_can_create_agent_allows_when_under_limit(client):
|
||||||
"""POST /agents/{id}/run returns 404 for unknown agent id."""
|
"""POST /agents/can-create returns allowed=True when under tier limit."""
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
f"/api/v1/agents/{uuid.uuid4()}/run",
|
"/api/v1/agents/can-create",
|
||||||
headers=auth_header("power"),
|
json={"active_agents": 0},
|
||||||
|
headers=auth_header("free"),
|
||||||
)
|
)
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["allowed"] is True
|
||||||
|
assert body["tier"] == "free"
|
||||||
|
assert body["active_agents"] == 0
|
||||||
|
assert body["limit"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_can_create_agent_denies_when_at_limit(client):
|
||||||
|
"""POST /agents/can-create returns allowed=False at free-tier limit."""
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/agents/can-create",
|
||||||
|
json={"active_agents": 2},
|
||||||
|
headers=auth_header("free"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["allowed"] is False
|
||||||
|
assert body["limit"] == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
||||||
"""POST /agents/{id}/run creates a run log and dispatches a background task."""
|
"""POST /agents/trigger creates a local run log and dispatches background task."""
|
||||||
# Create the local agent config in the DB.
|
dispatched: list[tuple[str, str]] = []
|
||||||
config = LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=TEST_USER_IDS["power"],
|
|
||||||
device_id="dev-001",
|
|
||||||
name="My Agent",
|
|
||||||
directory_paths=["/home/user/docs"],
|
|
||||||
data_types=["tasks"],
|
|
||||||
prompt_template="Extract tasks.",
|
|
||||||
file_extensions=[".txt"],
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
)
|
|
||||||
db_session.add(config)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
dispatched: list = []
|
|
||||||
|
|
||||||
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
||||||
dispatched.append((user_id, cfg.id))
|
dispatched.append((user_id, cfg.id))
|
||||||
|
|
||||||
|
def _fake_create_task(coro):
|
||||||
|
coro.close()
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
||||||
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
|
|
||||||
patch("asyncio.create_task") as mock_create_task:
|
patch("asyncio.create_task") as mock_create_task:
|
||||||
|
mock_create_task.side_effect = _fake_create_task
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
f"/api/v1/agents/{config.id}/run",
|
"/api/v1/agents/trigger",
|
||||||
|
json={
|
||||||
|
"directory": "/home/user/docs",
|
||||||
|
"what_to_extract": ["task", "note"],
|
||||||
|
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
|
||||||
|
"batch_interval": "0 */6 * * *",
|
||||||
|
"custom_agent_prompt": "Extract tasks and notes.",
|
||||||
|
"active_agents": 0,
|
||||||
|
},
|
||||||
headers=auth_header("power"),
|
headers=auth_header("power"),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert resp.status_code == 202
|
assert resp.status_code == 202
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert data["agent_id"] == config.id
|
assert isinstance(data["agent_id"], str)
|
||||||
|
assert data["agent_id"]
|
||||||
assert data["status"] == "running"
|
assert data["status"] == "running"
|
||||||
assert data["agent_type"] == "local"
|
assert data["agent_type"] == "local"
|
||||||
|
|
||||||
|
|||||||
@@ -1,416 +0,0 @@
|
|||||||
"""Tests for ChatAgent streaming and tool result capture (Step 2)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
|
|
||||||
|
|
||||||
# ── Minimal concrete agent for testing ───────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _EchoAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "_echo"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Echo agent for tests"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return query
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage:
|
|
||||||
msg = AIMessage(content=content)
|
|
||||||
if tool_calls:
|
|
||||||
msg.tool_calls = tool_calls
|
|
||||||
else:
|
|
||||||
msg.tool_calls = []
|
|
||||||
return msg
|
|
||||||
|
|
||||||
|
|
||||||
def _make_tool(name: str, return_value: Any) -> MagicMock:
|
|
||||||
t = MagicMock()
|
|
||||||
t.name = name
|
|
||||||
t.ainvoke = AsyncMock(return_value=return_value)
|
|
||||||
return t
|
|
||||||
|
|
||||||
|
|
||||||
def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]:
|
|
||||||
chunks = []
|
|
||||||
for tok in tokens:
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = tok
|
|
||||||
chunks.append(c)
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
|
|
||||||
async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]:
|
|
||||||
tokens: list[str] = []
|
|
||||||
async for tok in agent._tool_loop_stream(llm, messages, tools):
|
|
||||||
tokens.append(tok)
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
|
|
||||||
# ── tool_results initialised ─────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_tool_results_init():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
assert agent.tool_results == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop: no tool calls ────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_no_tools():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Hello!"))
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
|
|
||||||
assert result == "Hello!"
|
|
||||||
assert agent.tool_results == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop: with one tool call + result capture ──────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_captures_tool_results():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
# Mock execute_on_client to return structured data via the tool
|
|
||||||
raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]}
|
|
||||||
|
|
||||||
async def fake_executor(payload: dict) -> dict:
|
|
||||||
return raw_result
|
|
||||||
|
|
||||||
# AIMessage with a tool call, then a final answer
|
|
||||||
tool_call_msg = _make_ai_message(
|
|
||||||
tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}]
|
|
||||||
)
|
|
||||||
final_msg = _make_ai_message("Here are your tasks.")
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm_with_tools = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
|
||||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
|
||||||
|
|
||||||
mock_tool = _make_tool("list_tasks", "- Fix bug (todo)")
|
|
||||||
|
|
||||||
from app.core.ws_context import set_client_executor, clear_client_executor
|
|
||||||
set_client_executor(fake_executor)
|
|
||||||
try:
|
|
||||||
# Patch the tool to actually call execute_on_client
|
|
||||||
async def tool_side_effect(args: dict) -> str:
|
|
||||||
from app.core.ws_context import execute_on_client
|
|
||||||
res = await execute_on_client(action="select", table="tasks")
|
|
||||||
rows = res.get("rows", [])
|
|
||||||
return "\n".join(r["title"] for r in rows)
|
|
||||||
|
|
||||||
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
|
||||||
|
|
||||||
result = await agent._tool_loop(
|
|
||||||
llm, [HumanMessage(content="list my tasks")], [mock_tool]
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
assert result == "Here are your tasks."
|
|
||||||
assert len(agent.tool_results) == 1
|
|
||||||
assert agent.tool_results[0] == raw_result
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop: tool_results reset on each call ──────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_resets_tool_results():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
agent.tool_results = [{"stale": True}] # pre-populated from a previous call
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done."))
|
|
||||||
|
|
||||||
await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
|
|
||||||
assert agent.tool_results == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop: unknown tool name ────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_unknown_tool():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
# No known tools — model still calls a non-existent one; loop handles gracefully
|
|
||||||
tool_call_msg = _make_ai_message(
|
|
||||||
tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}]
|
|
||||||
)
|
|
||||||
final_msg = _make_ai_message("Handled.")
|
|
||||||
|
|
||||||
mock_tool = _make_tool("known", "ok") # a different tool, not "nonexistent"
|
|
||||||
llm = MagicMock()
|
|
||||||
llm_with_tools = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool])
|
|
||||||
assert result == "Handled."
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop: max_iter exhaustion ──────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_max_iter():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
always_tool = _make_ai_message(
|
|
||||||
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
|
|
||||||
)
|
|
||||||
fallback = _make_ai_message("Fallback.")
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm_with_tools = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
# Returns tool_call_msg on every iteration
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=fallback)
|
|
||||||
|
|
||||||
mock_tool = _make_tool("t", "ok")
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2)
|
|
||||||
assert result == "Fallback."
|
|
||||||
assert llm_with_tools.ainvoke.call_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop_stream: no tool calls — yields tokens ─────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_stream_no_tools_yields_tokens():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
# No tools → llm used directly; ainvoke returns no tool calls → stream is used
|
|
||||||
no_tool_msg = _make_ai_message("irrelevant")
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
|
||||||
|
|
||||||
async def fake_astream(msgs):
|
|
||||||
for tok in ["Hello", " ", "world"]:
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = tok
|
|
||||||
yield c
|
|
||||||
|
|
||||||
llm.astream = fake_astream
|
|
||||||
|
|
||||||
tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], [])
|
|
||||||
assert tokens == ["Hello", " ", "world"]
|
|
||||||
assert agent.tool_results == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop_stream: one tool call then streaming final ─────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_stream_with_tool_call():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}}
|
|
||||||
|
|
||||||
async def fake_executor(payload: dict) -> dict:
|
|
||||||
return raw_result
|
|
||||||
|
|
||||||
tool_call_msg = _make_ai_message(
|
|
||||||
tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}]
|
|
||||||
)
|
|
||||||
# After tools run, ainvoke returns no more tool calls
|
|
||||||
no_more_tools_msg = _make_ai_message("Task found.")
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm_with_tools = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
|
|
||||||
|
|
||||||
async def fake_astream(msgs):
|
|
||||||
for tok in ["Task", " ", "found."]:
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = tok
|
|
||||||
yield c
|
|
||||||
|
|
||||||
llm.astream = fake_astream
|
|
||||||
|
|
||||||
async def tool_side_effect(args: dict) -> str:
|
|
||||||
from app.core.ws_context import execute_on_client
|
|
||||||
res = await execute_on_client(action="select", table="tasks", filters={"id": args.get("id")})
|
|
||||||
return res.get("row", {}).get("title", "")
|
|
||||||
|
|
||||||
mock_tool = _make_tool("get_task", "Deploy")
|
|
||||||
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
|
||||||
|
|
||||||
from app.core.ws_context import set_client_executor, clear_client_executor
|
|
||||||
set_client_executor(fake_executor)
|
|
||||||
try:
|
|
||||||
tokens = await _collect_stream(
|
|
||||||
agent, llm, [HumanMessage(content="get task t-2")], [mock_tool]
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
assert tokens == ["Task", " ", "found."]
|
|
||||||
assert len(agent.tool_results) == 1
|
|
||||||
assert agent.tool_results[0] == raw_result
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop_stream: tool_results reset on each call ───────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_stream_resets_tool_results():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
agent.tool_results = [{"old": True}]
|
|
||||||
|
|
||||||
no_tool_msg = _make_ai_message("")
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
|
||||||
|
|
||||||
async def fake_astream(msgs):
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = "ok"
|
|
||||||
yield c
|
|
||||||
|
|
||||||
llm.astream = fake_astream
|
|
||||||
|
|
||||||
await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
|
|
||||||
assert agent.tool_results == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop_stream: empty chunk content is skipped ────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_stream_skips_empty_chunks():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
no_tool_msg = _make_ai_message("")
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
|
||||||
|
|
||||||
async def fake_astream(msgs):
|
|
||||||
for tok in ["", "hello", "", " world", ""]:
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = tok
|
|
||||||
yield c
|
|
||||||
|
|
||||||
llm.astream = fake_astream
|
|
||||||
|
|
||||||
tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
|
|
||||||
assert tokens == ["hello", " world"]
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop_stream: max_iter exhaustion falls back to stream ───────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_stream_max_iter():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
always_tool = _make_ai_message(
|
|
||||||
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm_with_tools = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
|
|
||||||
|
|
||||||
async def fake_astream(msgs):
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = "fallback"
|
|
||||||
yield c
|
|
||||||
|
|
||||||
llm.astream = fake_astream
|
|
||||||
mock_tool = _make_tool("t", "ok")
|
|
||||||
|
|
||||||
tokens = await _collect_stream(
|
|
||||||
agent, llm, [HumanMessage(content="x")], [mock_tool],
|
|
||||||
)
|
|
||||||
assert tokens == ["fallback"]
|
|
||||||
assert llm_with_tools.ainvoke.call_count == 5 # exhausted default max_iter
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop_stream: multiple tool results captured ────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_stream_multiple_tool_results():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
call_results = [
|
|
||||||
{"rows": [{"id": "t-1"}]},
|
|
||||||
{"rows": [{"id": "t-2"}]},
|
|
||||||
]
|
|
||||||
call_iter = iter(call_results)
|
|
||||||
|
|
||||||
async def fake_executor(payload: dict) -> dict:
|
|
||||||
return next(call_iter)
|
|
||||||
|
|
||||||
# Two tool calls in one iteration
|
|
||||||
tool_call_msg = _make_ai_message(
|
|
||||||
tool_calls=[
|
|
||||||
{"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"},
|
|
||||||
{"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
no_more_tools_msg = _make_ai_message("Done.")
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm_with_tools = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
|
|
||||||
|
|
||||||
async def fake_astream(msgs):
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = "Done."
|
|
||||||
yield c
|
|
||||||
|
|
||||||
llm.astream = fake_astream
|
|
||||||
|
|
||||||
async def tool_side_effect(args: dict) -> str:
|
|
||||||
from app.core.ws_context import execute_on_client
|
|
||||||
res = await execute_on_client(action="select", table="tasks")
|
|
||||||
return str(res)
|
|
||||||
|
|
||||||
tool_a = _make_tool("tool_a", "")
|
|
||||||
tool_a.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
|
||||||
tool_b = _make_tool("tool_b", "")
|
|
||||||
tool_b.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
|
||||||
|
|
||||||
from app.core.ws_context import set_client_executor, clear_client_executor
|
|
||||||
set_client_executor(fake_executor)
|
|
||||||
try:
|
|
||||||
tokens = await _collect_stream(
|
|
||||||
agent, llm, [HumanMessage(content="x")], [tool_a, tool_b]
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
assert tokens == ["Done."]
|
|
||||||
assert len(agent.tool_results) == 2
|
|
||||||
assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]}
|
|
||||||
assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]}
|
|
||||||
@@ -1,761 +0,0 @@
|
|||||||
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
|
||||||
from app.agents.timeline_agent import TimelineAgent
|
|
||||||
from app.agents.note_agent import NoteAgent
|
|
||||||
from app.agents.project_agent import ProjectAgent
|
|
||||||
from app.agents.task_agent import TaskAgent
|
|
||||||
from app.core.agent_registry import registry
|
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
|
||||||
|
|
||||||
|
|
||||||
# ── WS executor mock ──────────────────────────────────────────────────
|
|
||||||
#
|
|
||||||
# Tools call execute_on_client() which reads a ContextVar set by the WS
|
|
||||||
# handler. In unit tests there is no WS session, so we install a fake
|
|
||||||
# executor that returns plausible data for each action type.
|
|
||||||
|
|
||||||
_FAKE_ROW: dict[str, Any] = {
|
|
||||||
"id": "fake-id",
|
|
||||||
"title": "Fake Title",
|
|
||||||
"name": "Fake Name",
|
|
||||||
"status": "todo",
|
|
||||||
"priority": "medium",
|
|
||||||
"content": "Fake content",
|
|
||||||
"date": 1700000000000,
|
|
||||||
"taskId": "fake-task-id",
|
|
||||||
"author": "Alice",
|
|
||||||
"projectId": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def _fake_executor(payload: dict) -> dict:
|
|
||||||
action = payload.get("action", "")
|
|
||||||
if action == "select":
|
|
||||||
return {"rows": []}
|
|
||||||
if action == "insert":
|
|
||||||
data = payload.get("data", {})
|
|
||||||
return {"row": {**_FAKE_ROW, **data}}
|
|
||||||
if action == "update":
|
|
||||||
data = payload.get("data", {})
|
|
||||||
row = {**_FAKE_ROW, "id": data.get("id", "fake-id"), **data.get("updates", {})}
|
|
||||||
return {"row": row}
|
|
||||||
if action == "delete":
|
|
||||||
return {"deleted": True}
|
|
||||||
if action == "get":
|
|
||||||
data = payload.get("data", {})
|
|
||||||
return {"row": {**_FAKE_ROW, "id": data.get("id", "fake-id")}}
|
|
||||||
if action == "vector_upsert":
|
|
||||||
return {"ok": True}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def ws_executor():
|
|
||||||
"""Install a fake WS executor for every test so tools can run without a real WS."""
|
|
||||||
set_client_executor(_fake_executor)
|
|
||||||
yield
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_llm(response_text: str) -> MagicMock:
|
|
||||||
"""Return a mock LLM that responds with *response_text* (no tool calls)."""
|
|
||||||
msg = MagicMock()
|
|
||||||
msg.content = response_text
|
|
||||||
msg.tool_calls = []
|
|
||||||
llm = MagicMock()
|
|
||||||
bound = MagicMock()
|
|
||||||
bound.ainvoke = AsyncMock(return_value=msg)
|
|
||||||
llm.bind_tools = MagicMock(return_value=bound)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=msg)
|
|
||||||
return llm
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_llm_with_tool_call(
|
|
||||||
tool_name: str, tool_args: dict[str, Any], final_text: str
|
|
||||||
) -> MagicMock:
|
|
||||||
"""Mock LLM that fires one tool call then returns *final_text*."""
|
|
||||||
tool_msg = MagicMock()
|
|
||||||
tool_msg.content = ""
|
|
||||||
tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
|
|
||||||
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = final_text
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
bound = MagicMock()
|
|
||||||
bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=bound)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
|
||||||
return llm
|
|
||||||
|
|
||||||
|
|
||||||
# ── Registration ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgentRegistration:
|
|
||||||
def test_all_agents_registered(self) -> None:
|
|
||||||
names = {a["name"] for a in registry.list_agents()}
|
|
||||||
assert {
|
|
||||||
"task_agent", "timeline_agent", "project_agent", "note_agent"
|
|
||||||
}.issubset(names)
|
|
||||||
|
|
||||||
def test_registry_returns_correct_types(self) -> None:
|
|
||||||
assert isinstance(registry.get("task_agent"), TaskAgent)
|
|
||||||
assert isinstance(registry.get("timeline_agent"), TimelineAgent)
|
|
||||||
assert isinstance(registry.get("project_agent"), ProjectAgent)
|
|
||||||
assert isinstance(registry.get("note_agent"), NoteAgent)
|
|
||||||
|
|
||||||
def test_descriptions_present(self) -> None:
|
|
||||||
for agent_info in registry.list_agents():
|
|
||||||
assert agent_info["description"], f"Empty description: {agent_info['name']}"
|
|
||||||
|
|
||||||
|
|
||||||
# ── TaskAgent ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestTaskAgent:
|
|
||||||
def test_name(self) -> None:
|
|
||||||
assert TaskAgent().get_name() == "task_agent"
|
|
||||||
|
|
||||||
def test_description(self) -> None:
|
|
||||||
assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
|
|
||||||
def test_get_tools_count(self) -> None:
|
|
||||||
assert len(TaskAgent().get_tools()) == 8
|
|
||||||
|
|
||||||
def test_tool_names(self) -> None:
|
|
||||||
names = {t.name for t in TaskAgent().get_tools()}
|
|
||||||
assert names == {
|
|
||||||
"list_tasks",
|
|
||||||
"create_task",
|
|
||||||
"update_task",
|
|
||||||
"delete_task",
|
|
||||||
"list_tasks_due_today",
|
|
||||||
"list_task_comments",
|
|
||||||
"add_task_comment",
|
|
||||||
"delete_task_comment",
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_returns_string(self) -> None:
|
|
||||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Task created.")
|
|
||||||
result = await TaskAgent().handle("create a task", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_no_tool_calls(self) -> None:
|
|
||||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Here are your tasks.")
|
|
||||||
result = await TaskAgent().handle("list my tasks", {})
|
|
||||||
assert result == "Here are your tasks."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_with_create_task_tool_call(self) -> None:
|
|
||||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
|
||||||
"create_task",
|
|
||||||
{"title": "Buy groceries", "priority": "low"},
|
|
||||||
"Task 'Buy groceries' created.",
|
|
||||||
)
|
|
||||||
result = await TaskAgent().handle("add a grocery task", {})
|
|
||||||
assert result == "Task 'Buy groceries' created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_empty_context(self) -> None:
|
|
||||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Done.")
|
|
||||||
result = await TaskAgent().handle("help", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_rich_context(self) -> None:
|
|
||||||
context = {
|
|
||||||
"user_profile": {"id": "u1", "tier": "pro"},
|
|
||||||
"recent_tasks": [{"id": "t1", "title": "Old task"}],
|
|
||||||
}
|
|
||||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Tasks listed.")
|
|
||||||
result = await TaskAgent().handle("show tasks", context)
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestTaskAgentTools:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tasks_defaults(self) -> None:
|
|
||||||
from app.agents.task_agent import list_tasks
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_tasks.ainvoke({})
|
|
||||||
m.assert_called_once_with(
|
|
||||||
action="select", table="tasks",
|
|
||||||
filters={"projectId": None, "status": None, "search": None, "orderBy": None},
|
|
||||||
)
|
|
||||||
assert result == "No tasks found matching the given filters."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tasks_with_status_filter(self) -> None:
|
|
||||||
from app.agents.task_agent import list_tasks
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
await list_tasks.ainvoke({"status": "done"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["filters"]["status"] == "done"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_task_defaults(self) -> None:
|
|
||||||
from app.agents.task_agent import create_task
|
|
||||||
fake_row = {"id": "t1", "title": "Test task", "status": "todo", "priority": "medium"}
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await create_task.ainvoke({"title": "Test task"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "insert"
|
|
||||||
assert call_kwargs["table"] == "tasks"
|
|
||||||
assert call_kwargs["data"]["title"] == "Test task"
|
|
||||||
assert call_kwargs["data"]["status"] == "todo"
|
|
||||||
assert call_kwargs["data"]["priority"] == "medium"
|
|
||||||
assert "Test task" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_task_with_all_fields(self) -> None:
|
|
||||||
from app.agents.task_agent import create_task
|
|
||||||
fake_row = {"id": "t1", "title": "Deploy", "status": "in_progress", "priority": "high"}
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await create_task.ainvoke({
|
|
||||||
"title": "Deploy", "priority": "high", "status": "in_progress",
|
|
||||||
"project_id": "p1", "is_ai_suggested": 1,
|
|
||||||
})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["data"]["priority"] == "high"
|
|
||||||
assert call_kwargs["data"]["status"] == "in_progress"
|
|
||||||
assert call_kwargs["data"]["projectId"] == "p1"
|
|
||||||
assert call_kwargs["data"]["isAiSuggested"] == 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_task_with_status(self) -> None:
|
|
||||||
from app.agents.task_agent import update_task
|
|
||||||
fake_row = {"id": "t1", "title": "Buy groceries", "status": "done"}
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "update"
|
|
||||||
assert call_kwargs["data"]["id"] == "t1"
|
|
||||||
assert call_kwargs["data"]["updates"]["status"] == "done"
|
|
||||||
assert "t1" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_task_empty_updates(self) -> None:
|
|
||||||
from app.agents.task_agent import update_task
|
|
||||||
fake_row = {"id": "t1", "title": "Task", "status": "todo"}
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await update_task.ainvoke({"task_id": "t1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["data"]["updates"] == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_task(self) -> None:
|
|
||||||
from app.agents.task_agent import delete_task
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"deleted": True}
|
|
||||||
result = await delete_task.ainvoke({"task_id": "t1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "delete"
|
|
||||||
assert call_kwargs["table"] == "tasks"
|
|
||||||
assert call_kwargs["data"]["id"] == "t1"
|
|
||||||
assert "t1" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tasks_due_today(self) -> None:
|
|
||||||
from app.agents.task_agent import list_tasks_due_today
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_tasks_due_today.ainvoke({})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "select"
|
|
||||||
assert call_kwargs["table"] == "tasks"
|
|
||||||
assert "dueDateFrom" in call_kwargs["filters"]
|
|
||||||
assert result == "No tasks are due today."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_task_comments(self) -> None:
|
|
||||||
from app.agents.task_agent import list_task_comments
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_task_comments.ainvoke({"task_id": "t1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "select"
|
|
||||||
assert call_kwargs["table"] == "taskComments"
|
|
||||||
assert call_kwargs["filters"]["taskId"] == "t1"
|
|
||||||
assert "t1" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_add_task_comment(self) -> None:
|
|
||||||
from app.agents.task_agent import add_task_comment
|
|
||||||
fake_row = {"id": "c1", "taskId": "t1", "author": "Alice", "content": "Looks good!"}
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await add_task_comment.ainvoke({
|
|
||||||
"task_id": "t1", "author": "Alice", "content": "Looks good!",
|
|
||||||
})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "insert"
|
|
||||||
assert call_kwargs["table"] == "taskComments"
|
|
||||||
assert call_kwargs["data"]["taskId"] == "t1"
|
|
||||||
assert call_kwargs["data"]["author"] == "Alice"
|
|
||||||
assert call_kwargs["data"]["content"] == "Looks good!"
|
|
||||||
assert "Alice" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_task_comment(self) -> None:
|
|
||||||
from app.agents.task_agent import delete_task_comment
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"deleted": True}
|
|
||||||
result = await delete_task_comment.ainvoke({"comment_id": "c1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "delete"
|
|
||||||
assert call_kwargs["table"] == "taskComments"
|
|
||||||
assert call_kwargs["data"]["id"] == "c1"
|
|
||||||
assert "c1" in result
|
|
||||||
|
|
||||||
|
|
||||||
# ── TimelineAgent ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestTimelineAgent:
|
|
||||||
def test_name(self) -> None:
|
|
||||||
assert TimelineAgent().get_name() == "timeline_agent"
|
|
||||||
|
|
||||||
def test_description(self) -> None:
|
|
||||||
assert TimelineAgent().get_description() == "Manages project timelines (milestones): list, create, update, delete"
|
|
||||||
|
|
||||||
def test_get_tools_count(self) -> None:
|
|
||||||
assert len(TimelineAgent().get_tools()) == 4
|
|
||||||
|
|
||||||
def test_tool_names(self) -> None:
|
|
||||||
names = {t.name for t in TimelineAgent().get_tools()}
|
|
||||||
assert names == {"list_timelines", "create_timeline", "update_timeline", "delete_timeline"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_no_tool_calls(self) -> None:
|
|
||||||
with patch("app.agents.timeline_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("No timelines found.")
|
|
||||||
result = await TimelineAgent().handle("list timelines", {})
|
|
||||||
assert result == "No timelines found."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_with_create_tool_call(self) -> None:
|
|
||||||
with patch("app.agents.timeline_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
|
||||||
"create_timeline",
|
|
||||||
{"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
|
|
||||||
"Timeline 'MVP Launch' created.",
|
|
||||||
)
|
|
||||||
result = await TimelineAgent().handle("add MVP timeline", {})
|
|
||||||
assert result == "Timeline 'MVP Launch' created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_empty_context(self) -> None:
|
|
||||||
with patch("app.agents.timeline_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Done.")
|
|
||||||
result = await TimelineAgent().handle("show milestones", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestTimelineAgentTools:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_timelines_no_project(self) -> None:
|
|
||||||
from app.agents.timeline_agent import list_timelines
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_timelines.ainvoke({})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "select"
|
|
||||||
assert call_kwargs["table"] == "timelines"
|
|
||||||
assert call_kwargs["filters"]["projectId"] is None
|
|
||||||
assert result == "No timelines found."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_timelines_with_project(self) -> None:
|
|
||||||
from app.agents.timeline_agent import list_timelines
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
await list_timelines.ainvoke({"project_id": "p1"})
|
|
||||||
assert m.call_args.kwargs["filters"]["projectId"] == "p1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_timeline(self) -> None:
|
|
||||||
from app.agents.timeline_agent import create_timeline
|
|
||||||
fake_row = {"id": "cp1", "title": "Beta release", "date": 1700000000000}
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await create_timeline.ainvoke({
|
|
||||||
"project_id": "p1", "title": "Beta release", "date": 1700000000000,
|
|
||||||
})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "insert"
|
|
||||||
assert call_kwargs["table"] == "timelines"
|
|
||||||
assert call_kwargs["data"]["projectId"] == "p1"
|
|
||||||
assert call_kwargs["data"]["title"] == "Beta release"
|
|
||||||
assert call_kwargs["data"]["date"] == 1700000000000
|
|
||||||
assert "Beta release" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_timeline_ai_suggested(self) -> None:
|
|
||||||
from app.agents.timeline_agent import create_timeline
|
|
||||||
fake_row = {"id": "cp1", "title": "Review", "date": 1700000000000}
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await create_timeline.ainvoke({
|
|
||||||
"project_id": "p1", "title": "Review", "date": 1700000000000, "is_ai_suggested": 1,
|
|
||||||
})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["data"]["isAiSuggested"] == 1
|
|
||||||
assert call_kwargs["data"]["isApproved"] == 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_timeline_approve(self) -> None:
|
|
||||||
from app.agents.timeline_agent import update_timeline
|
|
||||||
fake_row = {"id": "c1", "title": "MVP", "isApproved": 1}
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await update_timeline.ainvoke({"timeline_id": "c1", "is_approved": 1})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "update"
|
|
||||||
assert call_kwargs["data"]["id"] == "c1"
|
|
||||||
assert call_kwargs["data"]["updates"]["isApproved"] == 1
|
|
||||||
assert "c1" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_timeline_empty_updates(self) -> None:
|
|
||||||
from app.agents.timeline_agent import update_timeline
|
|
||||||
fake_row = {"id": "c1", "title": "MVP"}
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await update_timeline.ainvoke({"timeline_id": "c1"})
|
|
||||||
assert m.call_args.kwargs["data"]["updates"] == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_timeline(self) -> None:
|
|
||||||
from app.agents.timeline_agent import delete_timeline
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"deleted": True}
|
|
||||||
result = await delete_timeline.ainvoke({"timeline_id": "c1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "delete"
|
|
||||||
assert call_kwargs["table"] == "timelines"
|
|
||||||
assert call_kwargs["data"]["id"] == "c1"
|
|
||||||
assert "c1" in result
|
|
||||||
|
|
||||||
|
|
||||||
# ── ProjectAgent ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestProjectAgent:
|
|
||||||
def test_name(self) -> None:
|
|
||||||
assert ProjectAgent().get_name() == "project_agent"
|
|
||||||
|
|
||||||
def test_description(self) -> None:
|
|
||||||
assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
|
|
||||||
def test_get_tools_count(self) -> None:
|
|
||||||
assert len(ProjectAgent().get_tools()) == 6
|
|
||||||
|
|
||||||
def test_tool_names(self) -> None:
|
|
||||||
names = {t.name for t in ProjectAgent().get_tools()}
|
|
||||||
assert names == {
|
|
||||||
"list_projects",
|
|
||||||
"list_all_projects",
|
|
||||||
"get_project",
|
|
||||||
"create_project",
|
|
||||||
"update_project",
|
|
||||||
"delete_project",
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_no_tool_calls(self) -> None:
|
|
||||||
with patch("app.agents.project_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Project Alpha is active.")
|
|
||||||
result = await ProjectAgent().handle("show my projects", {})
|
|
||||||
assert result == "Project Alpha is active."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_with_create_project_tool_call(self) -> None:
|
|
||||||
with patch("app.agents.project_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
|
||||||
"create_project",
|
|
||||||
{"name": "Pippo"},
|
|
||||||
"Project 'Pippo' created.",
|
|
||||||
)
|
|
||||||
result = await ProjectAgent().handle("create project Pippo", {})
|
|
||||||
assert result == "Project 'Pippo' created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_empty_context(self) -> None:
|
|
||||||
with patch("app.agents.project_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Done.")
|
|
||||||
result = await ProjectAgent().handle("archive old project", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestProjectAgentTools:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_projects_defaults(self) -> None:
|
|
||||||
from app.agents.project_agent import list_projects
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_projects.ainvoke({})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "select"
|
|
||||||
assert call_kwargs["table"] == "projects"
|
|
||||||
assert call_kwargs["filters"]["includeArchived"] is False
|
|
||||||
assert result == "No projects found."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_projects_include_archived(self) -> None:
|
|
||||||
from app.agents.project_agent import list_projects
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
await list_projects.ainvoke({"include_archived": 1})
|
|
||||||
assert m.call_args.kwargs["filters"]["includeArchived"] is True
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_all_projects(self) -> None:
|
|
||||||
from app.agents.project_agent import list_all_projects
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_all_projects.ainvoke({})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "select"
|
|
||||||
assert call_kwargs["table"] == "projects"
|
|
||||||
assert result == "No projects found."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_project(self) -> None:
|
|
||||||
from app.agents.project_agent import get_project
|
|
||||||
fake_row = {"id": "p1", "name": "Alpha", "status": "active", "clientId": None}
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await get_project.ainvoke({"project_id": "p1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "get"
|
|
||||||
assert call_kwargs["table"] == "projects"
|
|
||||||
assert call_kwargs["data"]["id"] == "p1"
|
|
||||||
assert "Alpha" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_project_name_only(self) -> None:
|
|
||||||
from app.agents.project_agent import create_project
|
|
||||||
fake_row = {"id": "p1", "name": "Alpha"}
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await create_project.ainvoke({"name": "Alpha"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "insert"
|
|
||||||
assert call_kwargs["data"]["name"] == "Alpha"
|
|
||||||
assert call_kwargs["data"]["clientId"] is None
|
|
||||||
assert "Alpha" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_project_with_client(self) -> None:
|
|
||||||
from app.agents.project_agent import create_project
|
|
||||||
fake_row = {"id": "p1", "name": "Beta"}
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
|
|
||||||
assert m.call_args.kwargs["data"]["clientId"] == "cl1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_project_archive(self) -> None:
|
|
||||||
from app.agents.project_agent import update_project
|
|
||||||
fake_row = {"id": "p1", "name": "Alpha", "status": "archived"}
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "update"
|
|
||||||
assert call_kwargs["data"]["id"] == "p1"
|
|
||||||
assert call_kwargs["data"]["updates"]["status"] == "archived"
|
|
||||||
assert "p1" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_project_empty_updates(self) -> None:
|
|
||||||
from app.agents.project_agent import update_project
|
|
||||||
fake_row = {"id": "p1", "name": "Alpha", "status": "active"}
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await update_project.ainvoke({"project_id": "p1"})
|
|
||||||
assert m.call_args.kwargs["data"]["updates"] == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_project(self) -> None:
|
|
||||||
from app.agents.project_agent import delete_project
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"deleted": True}
|
|
||||||
result = await delete_project.ainvoke({"project_id": "p1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "delete"
|
|
||||||
assert call_kwargs["data"]["id"] == "p1"
|
|
||||||
assert "p1" in result
|
|
||||||
|
|
||||||
|
|
||||||
# ── NoteAgent ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestNoteAgent:
|
|
||||||
def test_name(self) -> None:
|
|
||||||
assert NoteAgent().get_name() == "note_agent"
|
|
||||||
|
|
||||||
def test_description(self) -> None:
|
|
||||||
assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"
|
|
||||||
|
|
||||||
def test_get_tools_count(self) -> None:
|
|
||||||
assert len(NoteAgent().get_tools()) == 5
|
|
||||||
|
|
||||||
def test_tool_names(self) -> None:
|
|
||||||
names = {t.name for t in NoteAgent().get_tools()}
|
|
||||||
assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_no_tool_calls(self) -> None:
|
|
||||||
with patch("app.agents.note_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Note created.")
|
|
||||||
result = await NoteAgent().handle("create a note", {})
|
|
||||||
assert result == "Note created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_with_create_note_tool_call(self) -> None:
|
|
||||||
with patch("app.agents.note_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
|
||||||
"create_note",
|
|
||||||
{"title": "Daily log", "content": "# Today\nAll good."},
|
|
||||||
"Note 'Daily log' created.",
|
|
||||||
)
|
|
||||||
result = await NoteAgent().handle("log today's progress", {})
|
|
||||||
assert result == "Note 'Daily log' created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_empty_context(self) -> None:
|
|
||||||
with patch("app.agents.note_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Done.")
|
|
||||||
result = await NoteAgent().handle("show notes", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestNoteAgentTools:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_notes_no_project(self) -> None:
|
|
||||||
from app.agents.note_agent import list_notes
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_notes.ainvoke({})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "select"
|
|
||||||
assert call_kwargs["table"] == "notes"
|
|
||||||
assert call_kwargs["filters"]["projectId"] is None
|
|
||||||
assert result == "No notes found."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_notes_with_project(self) -> None:
|
|
||||||
from app.agents.note_agent import list_notes
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
await list_notes.ainvoke({"project_id": "p1"})
|
|
||||||
assert m.call_args.kwargs["filters"]["projectId"] == "p1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_note(self) -> None:
|
|
||||||
from app.agents.note_agent import get_note
|
|
||||||
fake_row = {"id": "n1", "title": "Daily log", "content": "# Today\nAll good."}
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await get_note.ainvoke({"note_id": "n1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "get"
|
|
||||||
assert call_kwargs["table"] == "notes"
|
|
||||||
assert call_kwargs["data"]["id"] == "n1"
|
|
||||||
assert "Daily log" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_note_minimal(self) -> None:
|
|
||||||
from app.agents.note_agent import create_note
|
|
||||||
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
|
|
||||||
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
me.return_value = [0.0] * 1536
|
|
||||||
result = await create_note.ainvoke({"title": "Daily log", "content": "# Today\nAll good."})
|
|
||||||
# First call: insert; second call: vector_upsert
|
|
||||||
first_call = m.call_args_list[0].kwargs
|
|
||||||
assert first_call["action"] == "insert"
|
|
||||||
assert first_call["table"] == "notes"
|
|
||||||
assert first_call["data"]["title"] == "Daily log"
|
|
||||||
assert first_call["data"]["content"] == "# Today\nAll good."
|
|
||||||
assert first_call["data"]["projectId"] is None
|
|
||||||
assert "Daily log" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_note_with_project(self) -> None:
|
|
||||||
from app.agents.note_agent import create_note
|
|
||||||
fake_row = {"id": "n1", "title": "Sprint notes", "projectId": "p1"}
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
|
|
||||||
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
me.return_value = [0.0] * 1536
|
|
||||||
await create_note.ainvoke({"title": "Sprint notes", "content": "## Sprint 1", "project_id": "p1"})
|
|
||||||
first_call = m.call_args_list[0].kwargs
|
|
||||||
assert first_call["data"]["projectId"] == "p1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_note_content_only(self) -> None:
|
|
||||||
from app.agents.note_agent import update_note
|
|
||||||
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
|
|
||||||
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
me.return_value = [0.0] * 1536
|
|
||||||
result = await update_note.ainvoke({"note_id": "n1", "content": "# Updated content"})
|
|
||||||
first_call = m.call_args_list[0].kwargs
|
|
||||||
assert first_call["action"] == "update"
|
|
||||||
assert first_call["data"]["id"] == "n1"
|
|
||||||
assert first_call["data"]["updates"]["content"] == "# Updated content"
|
|
||||||
assert "title" not in first_call["data"]["updates"]
|
|
||||||
assert "n1" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_note_empty_updates(self) -> None:
|
|
||||||
from app.agents.note_agent import update_note
|
|
||||||
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await update_note.ainvoke({"note_id": "n1"})
|
|
||||||
assert m.call_args.kwargs["data"]["updates"] == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_note(self) -> None:
|
|
||||||
from app.agents.note_agent import delete_note
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"deleted": True}
|
|
||||||
result = await delete_note.ainvoke({"note_id": "n1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "delete"
|
|
||||||
assert call_kwargs["table"] == "notes"
|
|
||||||
assert call_kwargs["data"]["id"] == "n1"
|
|
||||||
assert "n1" in result
|
|
||||||
288
tests/test_deep_agent.py
Normal file
288
tests/test_deep_agent.py
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
"""Unit tests for single-agent deep_agent flows with mocked tool results."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, timedelta
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.core.deep_agent import (
|
||||||
|
_infer_floating_domain,
|
||||||
|
_normalize_tagged_list_lines,
|
||||||
|
run_floating,
|
||||||
|
run_floating_stream,
|
||||||
|
run_home,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeTool:
|
||||||
|
name = "list_tasks"
|
||||||
|
|
||||||
|
async def ainvoke(self, args):
|
||||||
|
return {"rows": [{"id": "task-1", "title": "Mock Task"}], "echo": args}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeLLM:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.agent_calls = 0
|
||||||
|
|
||||||
|
def bind_tools(self, _tools):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def ainvoke(self, messages):
|
||||||
|
system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
|
||||||
|
if "strict domain classifier" in system_prompt:
|
||||||
|
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
|
||||||
|
|
||||||
|
self.agent_calls += 1
|
||||||
|
if self.agent_calls == 1:
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": "call-1",
|
||||||
|
"name": "list_tasks",
|
||||||
|
"args": {"project_id": "proj-1"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
||||||
|
assert tool_messages, "Expected at least one tool message"
|
||||||
|
return AIMessage(content=f"Final answer from mocked tool: {tool_messages[-1].content}")
|
||||||
|
|
||||||
|
async def astream(self, _messages):
|
||||||
|
yield SimpleNamespace(content="stream-")
|
||||||
|
yield SimpleNamespace(content="ok")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_home_uses_mocked_tool_result():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||||
|
):
|
||||||
|
out = await run_home("user-1", "list my tasks", {})
|
||||||
|
|
||||||
|
assert "Final answer from mocked tool" in out
|
||||||
|
assert "Mock Task" in out
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in run_floating_stream(
|
||||||
|
"user-1",
|
||||||
|
"show me timeline updates",
|
||||||
|
{"scope": {"type": "timeline", "id": "tl-1"}},
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
assert events[0] == (
|
||||||
|
"floating_domain",
|
||||||
|
{"type": "timeline", "id": "tl-1", "section": None},
|
||||||
|
)
|
||||||
|
assert ("token", "stream-") in events
|
||||||
|
assert ("token", "ok") in events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
|
||||||
|
class _ClassifierOnlyLLM:
|
||||||
|
async def ainvoke(self, _messages):
|
||||||
|
return AIMessage(
|
||||||
|
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
|
||||||
|
domain = await _infer_floating_domain(
|
||||||
|
"Quali sono i miei task per il progetto X",
|
||||||
|
{
|
||||||
|
"scope": {"type": "timeline"},
|
||||||
|
"resolved_project_id": "213213-312321-312312-421321",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert domain == {
|
||||||
|
"type": "project",
|
||||||
|
"id": "213213-312321-312312-421321",
|
||||||
|
"section": "task",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
|
||||||
|
raw = (
|
||||||
|
"Certo!\n\n"
|
||||||
|
"1. **Task A** — priorita high <task>[task-1]</task>\n"
|
||||||
|
"2. **Task B** — priorita medium <task>[task-2]</task>\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
out = _normalize_tagged_list_lines(raw, "quali sono le prossime attivita?")
|
||||||
|
|
||||||
|
assert "<task>[task-1]</task>" in out
|
||||||
|
assert "<task>[task-2]</task>" in out
|
||||||
|
assert "Task A" not in out
|
||||||
|
assert "Task B" not in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_month_future_only():
|
||||||
|
today = date.today()
|
||||||
|
tomorrow = today + timedelta(days=1)
|
||||||
|
yesterday = today - timedelta(days=1)
|
||||||
|
next_month = (today.replace(day=28) + timedelta(days=5)).replace(day=1)
|
||||||
|
|
||||||
|
raw = "\n".join(
|
||||||
|
[
|
||||||
|
f"- Milestone old — {yesterday.strftime('%d/%m/%Y')} <timeline>[tl-old]</timeline>",
|
||||||
|
f"- Milestone next — {tomorrow.strftime('%d/%m/%Y')} <timeline>[tl-next]</timeline>",
|
||||||
|
f"- Milestone future — {next_month.strftime('%d/%m/%Y')} <timeline>[tl-future]</timeline>",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
out = _normalize_tagged_list_lines(raw, "invece i miei eventi prossimi?")
|
||||||
|
|
||||||
|
assert "<timeline>[tl-next]</timeline>" in out
|
||||||
|
assert "<timeline>[tl-old]</timeline>" not in out
|
||||||
|
assert "<timeline>[tl-future]</timeline>" not in out
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_strips_xml_like_tags_from_final_text():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
async def _fake_run_single_agent(**_kwargs):
|
||||||
|
return (
|
||||||
|
"Hai 1 task:\\n"
|
||||||
|
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
||||||
|
):
|
||||||
|
text, _domain = await run_floating(
|
||||||
|
"user-1",
|
||||||
|
"quali task ho?",
|
||||||
|
{"scope": {"type": "task"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "<task>" not in text
|
||||||
|
assert "</task>" not in text
|
||||||
|
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
async def _fake_stream(**_kwargs):
|
||||||
|
yield "token", "Hai 1 task:\\n"
|
||||||
|
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in run_floating_stream(
|
||||||
|
"user-1",
|
||||||
|
"quali task ho?",
|
||||||
|
{"scope": {"type": "task"}},
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
token_events = [str(data) for event_type, data in events if event_type == "token"]
|
||||||
|
combined = "".join(token_events)
|
||||||
|
assert "<task>" not in combined
|
||||||
|
assert "</task>" not in combined
|
||||||
|
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
|
||||||
|
class _NoChunkLLM:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
def bind_tools(self, _tools):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def ainvoke(self, _messages):
|
||||||
|
self.calls += 1
|
||||||
|
if self.calls == 1:
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": "call-1",
|
||||||
|
"name": "list_tasks",
|
||||||
|
"args": {},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return AIMessage(content="No notes found.")
|
||||||
|
|
||||||
|
async def astream(self, _messages):
|
||||||
|
if False:
|
||||||
|
yield None
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM()), patch(
|
||||||
|
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in run_floating_stream(
|
||||||
|
"user-1",
|
||||||
|
"quali sono le note?",
|
||||||
|
{"scope": {"type": "note"}},
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
assert events[0][0] == "floating_domain"
|
||||||
|
assert ("token", "No notes found.") in events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
async def _fake_run_single_agent(**_kwargs):
|
||||||
|
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
||||||
|
):
|
||||||
|
text, _domain = await run_floating(
|
||||||
|
"user-1",
|
||||||
|
"quali task ho?",
|
||||||
|
{"scope": {"type": "task"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert text == "No results found."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
async def _fake_stream(**_kwargs):
|
||||||
|
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in run_floating_stream(
|
||||||
|
"user-1",
|
||||||
|
"quali task ho?",
|
||||||
|
{"scope": {"type": "task"}},
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
assert ("token", "No results found.") in events
|
||||||
@@ -1,286 +0,0 @@
|
|||||||
"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.core.execution_plan import (
|
|
||||||
ExecutionPlanBuilder,
|
|
||||||
PlanCache,
|
|
||||||
PromptTemplateRegistry,
|
|
||||||
plan_cache,
|
|
||||||
template_registry,
|
|
||||||
)
|
|
||||||
from app.schemas import ExecutionPlan
|
|
||||||
|
|
||||||
|
|
||||||
# ── PromptTemplateRegistry ────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestPromptTemplateRegistry:
|
|
||||||
def test_register_and_get(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
reg.register("tpl_foo", "You are a foo agent.")
|
|
||||||
assert reg.get("tpl_foo") == "You are a foo agent."
|
|
||||||
|
|
||||||
def test_get_unknown_raises_key_error(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
with pytest.raises(KeyError, match="tpl_missing"):
|
|
||||||
reg.get("tpl_missing")
|
|
||||||
|
|
||||||
def test_has_returns_true_for_registered(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
reg.register("tpl_x", "prompt text")
|
|
||||||
assert reg.has("tpl_x") is True
|
|
||||||
|
|
||||||
def test_has_returns_false_for_unregistered(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
assert reg.has("tpl_missing") is False
|
|
||||||
|
|
||||||
def test_list_ids_returns_all_registered_ids(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
reg.register("tpl_a", "a")
|
|
||||||
reg.register("tpl_b", "b")
|
|
||||||
assert set(reg.list_ids()) == {"tpl_a", "tpl_b"}
|
|
||||||
|
|
||||||
def test_list_ids_does_not_return_prompt_text(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
reg.register("tpl_secret", "top secret prompt")
|
|
||||||
ids = reg.list_ids()
|
|
||||||
assert "top secret prompt" not in ids
|
|
||||||
|
|
||||||
def test_overwrite_existing_template(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
reg.register("tpl_x", "v1")
|
|
||||||
reg.register("tpl_x", "v2")
|
|
||||||
assert reg.get("tpl_x") == "v2"
|
|
||||||
|
|
||||||
def test_empty_registry_has_no_ids(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
assert reg.list_ids() == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── ExecutionPlanBuilder ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestExecutionPlanBuilder:
|
|
||||||
def test_builds_empty_plan(self) -> None:
|
|
||||||
plan = ExecutionPlanBuilder("task_agent").build()
|
|
||||||
assert plan.agent == "task_agent"
|
|
||||||
assert plan.steps == []
|
|
||||||
|
|
||||||
def test_add_step_basic(self) -> None:
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_step("create_task", {"priority": "high"})
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
assert len(plan.steps) == 1
|
|
||||||
assert plan.steps[0].action == "create_task"
|
|
||||||
assert plan.steps[0].variables == {"priority": "high"}
|
|
||||||
assert plan.steps[0].prompt_template is None
|
|
||||||
assert plan.steps[0].data_from_step is None
|
|
||||||
|
|
||||||
def test_add_step_no_params(self) -> None:
|
|
||||||
plan = ExecutionPlanBuilder("task_agent").add_step("fetch").build()
|
|
||||||
assert plan.steps[0].variables is None
|
|
||||||
|
|
||||||
def test_add_llm_step(self) -> None:
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_llm_step("tpl_task_default", {"message": "hi"})
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
assert plan.steps[0].action == "llm"
|
|
||||||
assert plan.steps[0].prompt_template == "tpl_task_default"
|
|
||||||
assert plan.steps[0].variables == {"message": "hi"}
|
|
||||||
|
|
||||||
def test_add_llm_step_no_variables(self) -> None:
|
|
||||||
plan = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build()
|
|
||||||
assert plan.steps[0].variables is None
|
|
||||||
|
|
||||||
def test_add_data_step(self) -> None:
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_step("fetch_data")
|
|
||||||
.add_data_step("transform", data_from_step=0)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
assert plan.steps[1].action == "transform"
|
|
||||||
assert plan.steps[1].data_from_step == 0
|
|
||||||
|
|
||||||
def test_fluent_chaining_returns_builder(self) -> None:
|
|
||||||
builder = ExecutionPlanBuilder("analytics_agent")
|
|
||||||
result = builder.add_step("a")
|
|
||||||
assert result is builder
|
|
||||||
|
|
||||||
def test_fluent_chain_multiple_steps(self) -> None:
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("analytics_agent")
|
|
||||||
.add_llm_step("tpl_analytics_default")
|
|
||||||
.add_step("format_output")
|
|
||||||
.add_data_step("store", data_from_step=0)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
assert len(plan.steps) == 3
|
|
||||||
|
|
||||||
def test_build_validates_data_from_step_out_of_range(self) -> None:
|
|
||||||
with pytest.raises(ValueError, match="data_from_step"):
|
|
||||||
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=5).build()
|
|
||||||
|
|
||||||
def test_build_validates_data_from_step_self_reference(self) -> None:
|
|
||||||
"""data_from_step=0 on the first step (index 0) is invalid."""
|
|
||||||
with pytest.raises(ValueError, match="data_from_step"):
|
|
||||||
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=0).build()
|
|
||||||
|
|
||||||
def test_build_validates_data_from_step_negative(self) -> None:
|
|
||||||
with pytest.raises(ValueError, match="data_from_step"):
|
|
||||||
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=-1).build()
|
|
||||||
|
|
||||||
def test_valid_data_from_step_at_index_two(self) -> None:
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_step("step0")
|
|
||||||
.add_step("step1")
|
|
||||||
.add_data_step("step2", data_from_step=1)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
assert plan.steps[2].data_from_step == 1
|
|
||||||
|
|
||||||
def test_data_from_step_zero_valid_at_index_one(self) -> None:
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_step("step0")
|
|
||||||
.add_data_step("step1", data_from_step=0)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
assert plan.steps[1].data_from_step == 0
|
|
||||||
|
|
||||||
def test_build_returns_new_plan_each_call(self) -> None:
|
|
||||||
builder = ExecutionPlanBuilder("task_agent").add_step("do_thing")
|
|
||||||
plan1 = builder.build()
|
|
||||||
plan2 = builder.build()
|
|
||||||
assert plan1 is not plan2
|
|
||||||
assert plan1.steps == plan2.steps
|
|
||||||
|
|
||||||
def test_plan_is_execution_plan_instance(self) -> None:
|
|
||||||
plan = ExecutionPlanBuilder("task_agent").build()
|
|
||||||
assert isinstance(plan, ExecutionPlan)
|
|
||||||
|
|
||||||
|
|
||||||
# ── PlanCache ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestPlanCache:
|
|
||||||
def _plan(self, agent: str = "a") -> ExecutionPlan:
|
|
||||||
return ExecutionPlanBuilder(agent).build()
|
|
||||||
|
|
||||||
def test_cache_and_get(self) -> None:
|
|
||||||
cache = PlanCache()
|
|
||||||
plan = self._plan()
|
|
||||||
cache.cache_plan("key1", plan)
|
|
||||||
assert cache.get_plan("key1") is plan
|
|
||||||
|
|
||||||
def test_get_missing_returns_none(self) -> None:
|
|
||||||
cache = PlanCache()
|
|
||||||
assert cache.get_plan("nonexistent") is None
|
|
||||||
|
|
||||||
def test_get_all_playbooks_empty(self) -> None:
|
|
||||||
cache = PlanCache()
|
|
||||||
assert cache.get_all_playbooks() == []
|
|
||||||
|
|
||||||
def test_get_all_playbooks_returns_all_stored(self) -> None:
|
|
||||||
cache = PlanCache()
|
|
||||||
p1, p2 = self._plan("a"), self._plan("b")
|
|
||||||
cache.cache_plan("k1", p1)
|
|
||||||
cache.cache_plan("k2", p2)
|
|
||||||
playbooks = cache.get_all_playbooks()
|
|
||||||
assert len(playbooks) == 2
|
|
||||||
assert p1 in playbooks
|
|
||||||
assert p2 in playbooks
|
|
||||||
|
|
||||||
def test_lru_evicts_oldest_entry(self) -> None:
|
|
||||||
cache = PlanCache(maxsize=2)
|
|
||||||
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
|
|
||||||
cache.cache_plan("k1", p1)
|
|
||||||
cache.cache_plan("k2", p2)
|
|
||||||
cache.cache_plan("k3", p3) # k1 should be evicted
|
|
||||||
assert cache.get_plan("k1") is None
|
|
||||||
assert cache.get_plan("k2") is p2
|
|
||||||
assert cache.get_plan("k3") is p3
|
|
||||||
|
|
||||||
def test_lru_access_updates_recency(self) -> None:
|
|
||||||
cache = PlanCache(maxsize=2)
|
|
||||||
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
|
|
||||||
cache.cache_plan("k1", p1)
|
|
||||||
cache.cache_plan("k2", p2)
|
|
||||||
cache.get_plan("k1") # k1 is now most-recently used
|
|
||||||
cache.cache_plan("k3", p3) # k2 should be evicted (LRU)
|
|
||||||
assert cache.get_plan("k1") is p1
|
|
||||||
assert cache.get_plan("k2") is None
|
|
||||||
assert cache.get_plan("k3") is p3
|
|
||||||
|
|
||||||
def test_overwrite_existing_key(self) -> None:
|
|
||||||
cache = PlanCache()
|
|
||||||
p1, p2 = self._plan("a"), self._plan("b")
|
|
||||||
cache.cache_plan("same_key", p1)
|
|
||||||
cache.cache_plan("same_key", p2)
|
|
||||||
assert cache.get_plan("same_key") is p2
|
|
||||||
assert len(cache.get_all_playbooks()) == 1
|
|
||||||
|
|
||||||
def test_overwrite_does_not_consume_capacity(self) -> None:
|
|
||||||
cache = PlanCache(maxsize=2)
|
|
||||||
p1, p2 = self._plan("a"), self._plan("b")
|
|
||||||
cache.cache_plan("k1", p1)
|
|
||||||
cache.cache_plan("k1", p2) # overwrite, not a new slot
|
|
||||||
cache.cache_plan("k2", p1) # should fit without eviction
|
|
||||||
assert cache.get_plan("k1") is p2
|
|
||||||
assert cache.get_plan("k2") is p1
|
|
||||||
|
|
||||||
|
|
||||||
# ── Module-level singletons ───────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestModuleSingletons:
|
|
||||||
def test_template_registry_has_all_agent_defaults(self) -> None:
|
|
||||||
for agent in ("task_agent", "timeline_agent", "project_agent", "note_agent"):
|
|
||||||
assert template_registry.has(f"tpl_{agent}_default"), (
|
|
||||||
f"Missing template: tpl_{agent}_default"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_template_registry_has_operation_templates(self) -> None:
|
|
||||||
assert template_registry.has("tpl_task_extract_from_project")
|
|
||||||
assert template_registry.has("tpl_note_weekly_summary")
|
|
||||||
|
|
||||||
def test_template_registry_get_returns_non_empty_string(self) -> None:
|
|
||||||
text = template_registry.get("tpl_task_agent_default")
|
|
||||||
assert isinstance(text, str)
|
|
||||||
assert len(text) > 0
|
|
||||||
|
|
||||||
def test_plan_cache_has_prebuilt_playbooks(self) -> None:
|
|
||||||
assert len(plan_cache.get_all_playbooks()) >= 2
|
|
||||||
|
|
||||||
def test_playbook_create_tasks_from_project(self) -> None:
|
|
||||||
plan = plan_cache.get_plan("create_tasks_from_project")
|
|
||||||
assert plan is not None
|
|
||||||
assert plan.agent == "project_agent"
|
|
||||||
assert len(plan.steps) == 2
|
|
||||||
assert plan.steps[0].prompt_template == "tpl_task_extract_from_project"
|
|
||||||
assert plan.steps[1].data_from_step == 0
|
|
||||||
|
|
||||||
def test_playbook_generate_weekly_note(self) -> None:
|
|
||||||
plan = plan_cache.get_plan("generate_weekly_note")
|
|
||||||
assert plan is not None
|
|
||||||
assert plan.agent == "note_agent"
|
|
||||||
assert len(plan.steps) == 2
|
|
||||||
assert plan.steps[0].prompt_template == "tpl_note_weekly_summary"
|
|
||||||
assert plan.steps[1].data_from_step == 0
|
|
||||||
|
|
||||||
def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
|
|
||||||
"""Plans must not embed prompt text — only template IDs."""
|
|
||||||
for plan in plan_cache.get_all_playbooks():
|
|
||||||
for step in plan.steps:
|
|
||||||
if step.prompt_template is not None:
|
|
||||||
assert step.prompt_template.startswith("tpl_"), (
|
|
||||||
f"prompt_template looks like raw text: {step.prompt_template!r}"
|
|
||||||
)
|
|
||||||
@@ -110,6 +110,32 @@ async def test_enrich_context_returns_episodic_memory(db_session, user_with_key)
|
|||||||
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_filters_episodic_by_session_id(db_session, user_with_key):
|
||||||
|
target_session = str(uuid.uuid4())
|
||||||
|
other_session = str(uuid.uuid4())
|
||||||
|
db_session.add(MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=_enc("Target session memory"),
|
||||||
|
session_id=target_session,
|
||||||
|
))
|
||||||
|
db_session.add(MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=_enc("Other session memory"),
|
||||||
|
session_id=other_session,
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "any message", session_id=target_session)
|
||||||
|
|
||||||
|
episodic = ctx.get("episodic_memory", [])
|
||||||
|
assert any("Target session" in s for s in episodic)
|
||||||
|
assert not any("Other session" in s for s in episodic)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
||||||
# Add one pattern above threshold and one below
|
# Add one pattern above threshold and one below
|
||||||
@@ -229,6 +255,40 @@ async def test_update_core_upsert(db_session, user_with_key):
|
|||||||
assert _dec(rows[0].value_encrypted) == "fr"
|
assert _dec(rows[0].value_encrypted) == "fr"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_core_block_edit_ops(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
|
||||||
|
await middleware.update_core(USER_ID, "human", "Name: Roberto")
|
||||||
|
await middleware.append_core(USER_ID, "human", "Timezone: Europe/Rome")
|
||||||
|
replaced = await middleware.replace_core(USER_ID, "human", "Roberto", "Robert")
|
||||||
|
|
||||||
|
blocks = await middleware.list_core_blocks(USER_ID)
|
||||||
|
human = next(b for b in blocks if b["label"] == "human")
|
||||||
|
|
||||||
|
assert replaced is True
|
||||||
|
assert "Name: Robert" in human["value"]
|
||||||
|
assert "Timezone: Europe/Rome" in human["value"]
|
||||||
|
|
||||||
|
deleted = await middleware.delete_core(USER_ID, "human")
|
||||||
|
assert deleted is True
|
||||||
|
assert await middleware.get_core_block(USER_ID, "human") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_archival_and_recall_search_helpers(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
|
||||||
|
await middleware.insert_archival(USER_ID, "Project whitelist has release risk", source="assistant")
|
||||||
|
await middleware.store_episode(USER_ID, str(uuid.uuid4()), "How is whitelist?", "Whitelist is delayed")
|
||||||
|
|
||||||
|
arch = await middleware.search_archival(USER_ID, "whitelist", top_k=3)
|
||||||
|
rec = await middleware.search_recall(USER_ID, "delayed", top_k=3)
|
||||||
|
|
||||||
|
assert any("whitelist" in item.lower() for item in arch)
|
||||||
|
assert any("delayed" in item.lower() for item in rec)
|
||||||
|
|
||||||
|
|
||||||
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
||||||
|
|
||||||
def test_home_request_calls_memory_middleware(client):
|
def test_home_request_calls_memory_middleware(client):
|
||||||
@@ -240,25 +300,24 @@ def test_home_request_calls_memory_middleware(client):
|
|||||||
def __init__(self, db):
|
def __init__(self, db):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def enrich_context(self, user_id, message):
|
async def enrich_context(self, user_id, message, **kwargs):
|
||||||
enrich_calls.append((user_id, message))
|
enrich_calls.append((user_id, message))
|
||||||
return {"core_memory": {"tz": "UTC"}}
|
return {"core_memory": {"tz": "UTC"}}
|
||||||
|
|
||||||
async def store_episode(self, user_id, session_id, message, response):
|
async def store_episode(self, user_id, session_id, message, response, **kwargs):
|
||||||
store_calls.append((user_id, session_id, message, response))
|
store_calls.append((user_id, session_id, message, response))
|
||||||
|
|
||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
async def _mock_stream(user_id, message, context, reg=None):
|
async def _mock_stream(user_id, message, context):
|
||||||
# Verify memory context was injected
|
# Verify memory context was injected
|
||||||
assert context.get("core_memory") == {"tz": "UTC"}
|
assert context.get("core_memory") == {"tz": "UTC"}
|
||||||
yield "task_agent", ""
|
yield "token", "Done"
|
||||||
yield "task_agent", '{"type": "text", "content": "Done"}'
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
||||||
patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_stream),
|
patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_stream),
|
||||||
):
|
):
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
ws.send_text(json.dumps({
|
ws.send_text(json.dumps({
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from jose import jwt
|
|||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
from app.schemas import ChatResponse
|
|
||||||
from tests.conftest import TEST_USER_IDS
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -50,7 +49,6 @@ _CHAT_BODY = {
|
|||||||
"recent_tasks": [],
|
"recent_tasks": [],
|
||||||
"conversation_history": [],
|
"conversation_history": [],
|
||||||
},
|
},
|
||||||
"execution_mode": "direct",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -240,7 +238,7 @@ class TestRateLimitMiddleware:
|
|||||||
|
|
||||||
|
|
||||||
class TestSanitizerMiddleware:
|
class TestSanitizerMiddleware:
|
||||||
"""Mock ``orchestrate`` to inject controlled strings into chat responses."""
|
"""Mock ``run_home`` to inject controlled strings into chat responses."""
|
||||||
|
|
||||||
_CHAT_PATH = "/api/v1/chat"
|
_CHAT_PATH = "/api/v1/chat"
|
||||||
|
|
||||||
@@ -248,11 +246,10 @@ class TestSanitizerMiddleware:
|
|||||||
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
|
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
|
||||||
|
|
||||||
def _post_chat(self, client: TestClient, response_text: str) -> dict:
|
def _post_chat(self, client: TestClient, response_text: str) -> dict:
|
||||||
mock_response = ChatResponse(response=response_text, actions=[])
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.chat.orchestrate",
|
"app.api.routes.chat.run_home",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value=mock_response,
|
return_value=response_text,
|
||||||
):
|
):
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
self._CHAT_PATH,
|
self._CHAT_PATH,
|
||||||
|
|||||||
@@ -1,347 +0,0 @@
|
|||||||
"""Integration tests for the orchestrator module."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
|
||||||
from app.core.orchestrator import (
|
|
||||||
classify_intent,
|
|
||||||
orchestrate,
|
|
||||||
orchestrate_stream,
|
|
||||||
route_pipeline,
|
|
||||||
route_single,
|
|
||||||
)
|
|
||||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
|
||||||
|
|
||||||
|
|
||||||
# ── Stub agents ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _TaskAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "task_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages tasks: create, update, list, suggest"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return f"task: {query}"
|
|
||||||
|
|
||||||
|
|
||||||
class _CalendarAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "calendar_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Calendar management: events, conflicts, scheduling"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return f"calendar: {query}"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_llm(response_text: str) -> MagicMock:
|
|
||||||
"""Return a mock LLM that always produces *response_text*."""
|
|
||||||
msg = MagicMock()
|
|
||||||
msg.content = response_text
|
|
||||||
llm = MagicMock()
|
|
||||||
llm.ainvoke = AsyncMock(return_value=msg)
|
|
||||||
return llm
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _fresh_registry():
|
|
||||||
"""Reset the AgentRegistry singleton between tests."""
|
|
||||||
AgentRegistry._instance = None
|
|
||||||
yield
|
|
||||||
AgentRegistry._instance = None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def reg() -> AgentRegistry:
|
|
||||||
r = AgentRegistry()
|
|
||||||
r.register(_TaskAgent)
|
|
||||||
r.register(_CalendarAgent)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
# ── classify_intent ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestClassifyIntent:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
result = await classify_intent("add a task", {}, reg)
|
|
||||||
assert result == "task_agent"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("calendar_agent")
|
|
||||||
result = await classify_intent("schedule a meeting", {}, reg)
|
|
||||||
assert result == "calendar_agent"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("nonexistent_agent")
|
|
||||||
result = await classify_intent("do something", {}, reg)
|
|
||||||
assert result == "task_agent"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
|
|
||||||
empty_reg = AgentRegistry()
|
|
||||||
# No LLM should be instantiated — early return path
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
result = await classify_intent("anything", {}, empty_reg)
|
|
||||||
mock_cls.assert_not_called()
|
|
||||||
assert result == "task_agent"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm(" task_agent \n")
|
|
||||||
result = await classify_intent("create task", {}, reg)
|
|
||||||
assert result == "task_agent"
|
|
||||||
|
|
||||||
|
|
||||||
# ── route_single ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestRouteSingle:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
|
|
||||||
result = await route_single("task_agent", "create a task", {}, reg)
|
|
||||||
assert isinstance(result, ChatResponse)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
|
|
||||||
result = await route_single("task_agent", "create a task", {}, reg)
|
|
||||||
assert result.response == "task: create a task"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
await route_single("nonexistent", "hello", {}, reg)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
|
|
||||||
result = await route_single("task_agent", "hi", {}, reg)
|
|
||||||
assert result.actions == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── route_pipeline ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestRoutePipeline:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("synthesized result")
|
|
||||||
result = await route_pipeline(
|
|
||||||
["task_agent", "calendar_agent"], "plan my week", {}, reg
|
|
||||||
)
|
|
||||||
assert isinstance(result, ChatResponse)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("synthesized result")
|
|
||||||
result = await route_pipeline(
|
|
||||||
["task_agent", "calendar_agent"], "plan my week", {}, reg
|
|
||||||
)
|
|
||||||
assert result.response == "synthesized result"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_passes_previous_results_to_subsequent_agents(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
"""Each agent after the first should receive prior outputs in context."""
|
|
||||||
received_contexts: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
class _CapturingAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "capture"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "captures context for testing"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
received_contexts.append(dict(context))
|
|
||||||
return "captured"
|
|
||||||
|
|
||||||
reg.register(_CapturingAgent)
|
|
||||||
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("done")
|
|
||||||
await route_pipeline(["task_agent", "capture"], "hi", {}, reg)
|
|
||||||
|
|
||||||
# The second agent (capture) must have received previous results
|
|
||||||
assert len(received_contexts) == 1
|
|
||||||
assert "previous_results" in received_contexts[0]
|
|
||||||
assert received_contexts[0]["previous_results"] == ["task: hi"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("single result")
|
|
||||||
result = await route_pipeline(["task_agent"], "one agent", {}, reg)
|
|
||||||
assert result.response == "single result"
|
|
||||||
|
|
||||||
|
|
||||||
# ── orchestrate ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestOrchestrate:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_direct_mode_returns_chat_response(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ChatResponse)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ChatResponse)
|
|
||||||
assert result.response == "task: add a task"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_plan_mode_returns_execution_plan(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="plan my tasks", execution_mode="plan")
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ExecutionPlan)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_plan_mode_agent_matches_classified(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("calendar_agent")
|
|
||||||
request = ChatRequest(
|
|
||||||
message="schedule something", execution_mode="plan"
|
|
||||||
)
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ExecutionPlan)
|
|
||||||
assert result.agent == "calendar_agent"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="plan tasks", execution_mode="plan")
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ExecutionPlan)
|
|
||||||
assert len(result.steps) >= 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_plan_mode_template_id_contains_agent_name(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="plan tasks", execution_mode="plan")
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ExecutionPlan)
|
|
||||||
assert result.steps[0].prompt_template is not None
|
|
||||||
assert "task_agent" in result.steps[0].prompt_template
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_default_execution_mode_is_direct(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
# execution_mode defaults to "direct"
|
|
||||||
request = ChatRequest(message="help me")
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ChatResponse)
|
|
||||||
|
|
||||||
|
|
||||||
# ── orchestrate_stream ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestOrchestrateStream:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
|
||||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
|
||||||
assert len(chunks) >= 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_all_chunks_are_plain_text(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
|
||||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
|
||||||
|
|
||||||
# orchestrate_stream yields plain text chunks only — no JSON final frame
|
|
||||||
for chunk in chunks:
|
|
||||||
assert isinstance(chunk, str)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_concatenated_chunks_equal_full_response(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="create a task", execution_mode="direct")
|
|
||||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
|
||||||
|
|
||||||
full_text = "".join(chunks)
|
|
||||||
assert full_text == "task: create a task"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_chunks_before_final_frame(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(
|
|
||||||
message="x" * 200, execution_mode="direct"
|
|
||||||
) # long enough to produce multiple chunks
|
|
||||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
|
||||||
|
|
||||||
# All but the last chunk should be plain text (not valid final JSON)
|
|
||||||
non_final = chunks[:-1]
|
|
||||||
for chunk in non_final:
|
|
||||||
try:
|
|
||||||
parsed = json.loads(chunk)
|
|
||||||
assert parsed.get("done") is not True
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass # plain text chunk — expected
|
|
||||||
@@ -1,236 +0,0 @@
|
|||||||
"""Tests for v3 orchestrator functions (Step 3)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, AgentRegistry
|
|
||||||
from app.core.orchestrator import orchestrate_v3, orchestrate_v3_stream
|
|
||||||
|
|
||||||
|
|
||||||
# ── Minimal agent for testing ─────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _FixedAgent(ChatAgent):
|
|
||||||
def __init__(self, name: str = "_fixed", tokens: list[str] | None = None, **kwargs: Any) -> None:
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._name = name
|
|
||||||
self._tokens = tokens or ["Hello", " world"]
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Fixed agent for tests"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return "".join(self._tokens)
|
|
||||||
|
|
||||||
async def handle_stream(self, query: str, context: dict[str, Any]):
|
|
||||||
for tok in self._tokens:
|
|
||||||
yield tok
|
|
||||||
|
|
||||||
|
|
||||||
# ── Mock registry factory ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _make_registry(agent_name: str, agent: ChatAgent) -> MagicMock:
|
|
||||||
reg = MagicMock(spec=AgentRegistry)
|
|
||||||
reg.list_agents.return_value = [{"name": agent_name, "description": "test"}]
|
|
||||||
reg.get.return_value = agent
|
|
||||||
return reg
|
|
||||||
|
|
||||||
|
|
||||||
# ── orchestrate_v3 ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_orchestrate_v3_returns_agent_name_and_instance():
|
|
||||||
agent = _FixedAgent("task_agent")
|
|
||||||
reg = _make_registry("task_agent", agent)
|
|
||||||
|
|
||||||
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
|
|
||||||
name, inst = await orchestrate_v3(
|
|
||||||
user_id="u-1", message="fix a bug", context={}, reg=reg
|
|
||||||
)
|
|
||||||
|
|
||||||
assert name == "task_agent"
|
|
||||||
assert inst is agent
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_orchestrate_v3_classify_called_with_message_and_context():
|
|
||||||
agent = _FixedAgent("note_agent")
|
|
||||||
reg = _make_registry("note_agent", agent)
|
|
||||||
ctx = {"some": "context"}
|
|
||||||
|
|
||||||
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")) as mock_classify:
|
|
||||||
await orchestrate_v3(user_id="u-1", message="take a note", context=ctx, reg=reg)
|
|
||||||
|
|
||||||
mock_classify.assert_awaited_once()
|
|
||||||
call_args = mock_classify.call_args
|
|
||||||
assert call_args[0][0] == "take a note"
|
|
||||||
assert call_args[0][1] == ctx
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_orchestrate_v3_uses_default_registry_when_none():
|
|
||||||
agent = _FixedAgent("task_agent")
|
|
||||||
|
|
||||||
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
|
|
||||||
patch("app.core.orchestrator._default_registry") as mock_reg:
|
|
||||||
mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
|
|
||||||
mock_reg.get.return_value = agent
|
|
||||||
name, inst = await orchestrate_v3(user_id="u-1", message="hi", context={})
|
|
||||||
|
|
||||||
assert name == "task_agent"
|
|
||||||
assert inst is agent
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_orchestrate_v3_get_called_with_agent_name():
|
|
||||||
agent = _FixedAgent("timeline_agent")
|
|
||||||
reg = _make_registry("timeline_agent", agent)
|
|
||||||
|
|
||||||
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="timeline_agent")):
|
|
||||||
await orchestrate_v3(user_id="u-2", message="schedule", context={}, reg=reg)
|
|
||||||
|
|
||||||
reg.get.assert_called_once_with("timeline_agent")
|
|
||||||
|
|
||||||
|
|
||||||
# ── orchestrate_v3_stream ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _collect(gen) -> list[tuple[str, str]]:
|
|
||||||
results: list[tuple[str, str]] = []
|
|
||||||
async for item in gen:
|
|
||||||
results.append(item)
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_orchestrate_v3_stream_first_yield_is_domain_signal():
|
|
||||||
agent = _FixedAgent("task_agent", tokens=["token1"])
|
|
||||||
reg = _make_registry("task_agent", agent)
|
|
||||||
|
|
||||||
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
|
|
||||||
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
|
|
||||||
results = await _collect(gen)
|
|
||||||
|
|
||||||
# First item must be (agent_name, "") — domain signal
|
|
||||||
assert results[0] == ("task_agent", "")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_orchestrate_v3_stream_yields_agent_name_with_tokens():
|
|
||||||
agent = _FixedAgent("task_agent", tokens=["Hello", " ", "world"])
|
|
||||||
reg = _make_registry("task_agent", agent)
|
|
||||||
|
|
||||||
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
|
|
||||||
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
|
|
||||||
results = await _collect(gen)
|
|
||||||
|
|
||||||
# All items are (agent_name, token) pairs
|
|
||||||
assert all(name == "task_agent" for name, _ in results)
|
|
||||||
tokens = [tok for _, tok in results]
|
|
||||||
assert tokens[0] == "" # domain signal
|
|
||||||
assert tokens[1:] == ["Hello", " ", "world"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_orchestrate_v3_stream_different_agent():
|
|
||||||
agent = _FixedAgent("note_agent", tokens=["note"])
|
|
||||||
reg = _make_registry("note_agent", agent)
|
|
||||||
|
|
||||||
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")):
|
|
||||||
gen = orchestrate_v3_stream(user_id="u-2", message="take note", context={}, reg=reg)
|
|
||||||
results = await _collect(gen)
|
|
||||||
|
|
||||||
assert results[0] == ("note_agent", "")
|
|
||||||
assert ("note_agent", "note") in results
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_orchestrate_v3_stream_uses_default_registry_when_none():
|
|
||||||
agent = _FixedAgent("task_agent", tokens=["x"])
|
|
||||||
|
|
||||||
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
|
|
||||||
patch("app.core.orchestrator._default_registry") as mock_reg:
|
|
||||||
mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
|
|
||||||
mock_reg.get.return_value = agent
|
|
||||||
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={})
|
|
||||||
results = await _collect(gen)
|
|
||||||
|
|
||||||
assert results[0][0] == "task_agent"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_orchestrate_v3_stream_empty_token_list():
|
|
||||||
"""Agent with no tokens still emits the domain signal."""
|
|
||||||
|
|
||||||
class _EmptyAgent(_FixedAgent):
|
|
||||||
async def handle_stream(self, query: str, context: dict[str, Any]):
|
|
||||||
return
|
|
||||||
yield # makes it a generator
|
|
||||||
|
|
||||||
agent = _EmptyAgent("task_agent", tokens=[])
|
|
||||||
reg = _make_registry("task_agent", agent)
|
|
||||||
|
|
||||||
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
|
|
||||||
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
|
|
||||||
results = await _collect(gen)
|
|
||||||
|
|
||||||
assert results == [("task_agent", "")] # only domain signal
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_orchestrate_v3_stream_full_text_correct():
|
|
||||||
"""Concatenating all non-domain tokens reconstructs the full response."""
|
|
||||||
tokens = ["The", " ", "task", " ", "is", " ", "done."]
|
|
||||||
agent = _FixedAgent("task_agent", tokens=tokens)
|
|
||||||
reg = _make_registry("task_agent", agent)
|
|
||||||
|
|
||||||
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
|
|
||||||
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
|
|
||||||
results = await _collect(gen)
|
|
||||||
|
|
||||||
text = "".join(tok for _, tok in results[1:]) # skip domain signal
|
|
||||||
assert text == "The task is done."
|
|
||||||
|
|
||||||
|
|
||||||
# ── handle_stream default implementation ─────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_stream_default_yields_full_response():
|
|
||||||
"""Default handle_stream yields handle() result as a single chunk."""
|
|
||||||
|
|
||||||
class _SimpleAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "_simple"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return "simple response"
|
|
||||||
|
|
||||||
agent = _SimpleAgent()
|
|
||||||
tokens = [tok async for tok in agent.handle_stream("q", {})]
|
|
||||||
assert tokens == ["simple response"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_stream_override_used_by_stream():
|
|
||||||
"""_FixedAgent.handle_stream override yields individual tokens."""
|
|
||||||
agent = _FixedAgent("t", tokens=["a", "b", "c"])
|
|
||||||
tokens = [tok async for tok in agent.handle_stream("q", {})]
|
|
||||||
assert tokens == ["a", "b", "c"]
|
|
||||||
@@ -1,195 +1,82 @@
|
|||||||
"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
|
"""Tests for app.core.output_formatter.StreamFormatter."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
from app.core.output_formatter import StreamFormatter
|
||||||
from app.schemas import (
|
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
WsFloatingDomain,
|
|
||||||
WsStreamBlock,
|
|
||||||
WsStreamEnd,
|
|
||||||
WsStreamStart,
|
|
||||||
WsStreamText,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── helpers ───────────────────────────────────────────────────────────────────
|
async def _stream(*events: tuple[str, object]):
|
||||||
|
for event in events:
|
||||||
async def _stream(*pairs: tuple[str, str]):
|
yield event
|
||||||
"""Async generator that yields (agent_name, token) pairs."""
|
|
||||||
for pair in pairs:
|
|
||||||
yield pair
|
|
||||||
|
|
||||||
|
|
||||||
async def collect(formatter, token_stream):
|
async def _collect(formatter: StreamFormatter, event_stream):
|
||||||
frames = []
|
frames = []
|
||||||
async for frame in formatter.format(token_stream):
|
async for frame in formatter.format(event_stream):
|
||||||
frames.append(frame)
|
frames.append(frame)
|
||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_home_formatter_text_block():
|
async def test_stream_formatter_text_stream() -> None:
|
||||||
req_id = "req-1"
|
formatter = StreamFormatter(request_id="req-1")
|
||||||
tokens = [
|
frames = await _collect(
|
||||||
("task_agent", '{"type": "text", "content": "Hello world"}'),
|
formatter,
|
||||||
]
|
_stream(("token", "Hello"), ("token", " world")),
|
||||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
|
||||||
frames = await collect(formatter, _stream(*tokens))
|
|
||||||
|
|
||||||
assert isinstance(frames[0], WsStreamStart)
|
|
||||||
assert frames[0].request_id == req_id
|
|
||||||
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
|
||||||
assert any("Hello world" in f.chunk for f in text_frames)
|
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_chart_block():
|
|
||||||
req_id = "req-2"
|
|
||||||
chart_json = (
|
|
||||||
'{"type": "chart", "chartType": "bar", '
|
|
||||||
'"title": "Tasks", "data": [{"x": 1}], '
|
|
||||||
'"config": {"x": {"label": "X", "color": "#fff"}}}'
|
|
||||||
)
|
)
|
||||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
|
||||||
frames = await collect(formatter, _stream(("task_agent", chart_json)))
|
|
||||||
|
|
||||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
|
||||||
assert len(block_frames) == 1
|
|
||||||
assert block_frames[0].block_type == "chart"
|
|
||||||
assert block_frames[0].data["chartType"] == "bar"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_invalid_chart_skipped():
|
|
||||||
req_id = "req-3"
|
|
||||||
bad_chart = '{"type": "chart", "chartType": "unknown", "data": []}'
|
|
||||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
|
||||||
frames = await collect(formatter, _stream(("task_agent", bad_chart)))
|
|
||||||
|
|
||||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
|
||||||
assert len(block_frames) == 0 # invalid chart skipped
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_entity_ref_resolved():
|
|
||||||
req_id = "req-4"
|
|
||||||
tool_results = [{"entity": "task", "id": "t1", "title": "My Task"}]
|
|
||||||
entity_json = '{"type": "entity_ref", "entity": "task"}'
|
|
||||||
formatter = HomeFormatter(request_id=req_id, tool_results=tool_results)
|
|
||||||
frames = await collect(formatter, _stream(("task_agent", entity_json)))
|
|
||||||
|
|
||||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
|
||||||
assert len(block_frames) == 1
|
|
||||||
assert block_frames[0].data["entity"] == "task"
|
|
||||||
assert block_frames[0].data["items"][0]["id"] == "t1"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_entity_ref_missing_skipped():
|
|
||||||
req_id = "req-5"
|
|
||||||
entity_json = '{"type": "entity_ref", "entity": "task"}'
|
|
||||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
|
||||||
frames = await collect(formatter, _stream(("task_agent", entity_json)))
|
|
||||||
|
|
||||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
|
||||||
assert len(block_frames) == 0 # no tool results → skipped
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_table_block():
|
|
||||||
req_id = "req-6"
|
|
||||||
table_json = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}'
|
|
||||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
|
||||||
frames = await collect(formatter, _stream(("task_agent", table_json)))
|
|
||||||
|
|
||||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
|
||||||
assert len(block_frames) == 1
|
|
||||||
assert block_frames[0].block_type == "table"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_timeline_block():
|
|
||||||
req_id = "req-7"
|
|
||||||
timeline_json = '{"type": "timeline", "timelines": [{"id": "c1", "title": "M1", "date": 123}]}'
|
|
||||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
|
||||||
frames = await collect(formatter, _stream(("task_agent", timeline_json)))
|
|
||||||
|
|
||||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
|
||||||
assert len(block_frames) == 1
|
|
||||||
assert block_frames[0].block_type == "timeline"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_frame_order():
|
|
||||||
"""stream_start is first, stream_end is last."""
|
|
||||||
req_id = "req-8"
|
|
||||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
|
||||||
frames = await collect(formatter, _stream(("task_agent", '{"type": "text", "content": "Hi"}')))
|
|
||||||
assert isinstance(frames[0], WsStreamStart)
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
|
assert isinstance(frames[1], WsStreamText)
|
||||||
|
assert frames[1].chunk == "Hello"
|
||||||
|
assert isinstance(frames[2], WsStreamText)
|
||||||
|
assert frames[2].chunk == " world"
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
# ── FloatingFormatter ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_floating_formatter_domain_emitted_first():
|
async def test_stream_formatter_floating_domain_first() -> None:
|
||||||
req_id = "pop-1"
|
formatter = StreamFormatter(request_id="req-2")
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
frames = await _collect(
|
||||||
tokens = [
|
formatter,
|
||||||
("task_agent", ""), # domain signal
|
_stream(
|
||||||
("task_agent", "Hello"),
|
(
|
||||||
("task_agent", " there"),
|
"floating_domain",
|
||||||
]
|
{"type": "node", "id": "n-1", "section": None},
|
||||||
frames = await collect(formatter, _stream(*tokens))
|
),
|
||||||
|
("token", "Summary"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(frames[0], WsFloatingDomain)
|
assert isinstance(frames[0], WsFloatingDomain)
|
||||||
assert frames[0].domain == "tasks"
|
assert frames[0].domain.type == "node"
|
||||||
assert frames[0].request_id == req_id
|
assert frames[0].domain.id == "n-1"
|
||||||
|
assert isinstance(frames[1], WsStreamStart)
|
||||||
|
assert isinstance(frames[2], WsStreamText)
|
||||||
|
assert frames[2].chunk == "Summary"
|
||||||
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_floating_formatter_text_only():
|
async def test_stream_formatter_ignores_unknown_events() -> None:
|
||||||
req_id = "pop-2"
|
formatter = StreamFormatter(request_id="req-3")
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
frames = await _collect(
|
||||||
tokens = [("timeline_agent", ""), ("timeline_agent", "Summary")]
|
formatter,
|
||||||
frames = await collect(formatter, _stream(*tokens))
|
_stream(("tool_end", {"name": "x"}), ("token", "ok")),
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(frames[0], WsFloatingDomain)
|
|
||||||
assert frames[0].domain == "timelines"
|
|
||||||
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
||||||
assert len(text_frames) == 1
|
assert len(text_frames) == 1
|
||||||
assert text_frames[0].chunk == "Summary"
|
assert text_frames[0].chunk == "ok"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_floating_formatter_no_block_frames():
|
async def test_stream_formatter_empty_stream_still_brackets() -> None:
|
||||||
"""FloatingFormatter must never emit WsStreamBlock."""
|
formatter = StreamFormatter(request_id="req-4")
|
||||||
req_id = "pop-3"
|
frames = await _collect(formatter, _stream())
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
tokens = [
|
|
||||||
("note_agent", ""),
|
|
||||||
("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'),
|
|
||||||
]
|
|
||||||
frames = await collect(formatter, _stream(*tokens))
|
|
||||||
assert not any(isinstance(f, WsStreamBlock) for f in frames)
|
|
||||||
|
|
||||||
|
assert len(frames) == 2
|
||||||
@pytest.mark.asyncio
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
async def test_floating_formatter_end_frame():
|
assert isinstance(frames[1], WsStreamEnd)
|
||||||
req_id = "pop-4"
|
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done")))
|
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_floating_formatter_unknown_agent_defaults_to_tasks():
|
|
||||||
req_id = "pop-5"
|
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi")))
|
|
||||||
assert frames[0].domain == "tasks"
|
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ import pytest
|
|||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
|
WsDomain,
|
||||||
WsFrameType,
|
WsFrameType,
|
||||||
WsHomeRequest,
|
WsHomeRequest,
|
||||||
WsFloatingDomain,
|
WsFloatingDomain,
|
||||||
WsFloatingRequest,
|
WsFloatingRequest,
|
||||||
WsFloatingScope,
|
WsFloatingScope,
|
||||||
WsStreamBlock,
|
|
||||||
WsStreamEnd,
|
WsStreamEnd,
|
||||||
WsStreamStart,
|
WsStreamStart,
|
||||||
WsStreamText,
|
WsStreamText,
|
||||||
@@ -25,7 +25,6 @@ def test_v3_frame_types_exist():
|
|||||||
"floating_request",
|
"floating_request",
|
||||||
"stream_start",
|
"stream_start",
|
||||||
"stream_text",
|
"stream_text",
|
||||||
"stream_block",
|
|
||||||
"stream_end",
|
"stream_end",
|
||||||
"floating_domain",
|
"floating_domain",
|
||||||
"data_request",
|
"data_request",
|
||||||
@@ -174,89 +173,21 @@ def test_stream_text_deserializes():
|
|||||||
assert frame.chunk == "test"
|
assert frame.chunk == "test"
|
||||||
|
|
||||||
|
|
||||||
# ── WsStreamBlock ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_block_chart():
|
|
||||||
data = {
|
|
||||||
"type": "chart",
|
|
||||||
"chartType": "bar",
|
|
||||||
"title": "Tasks",
|
|
||||||
"data": [{"name": "Done", "count": 5}],
|
|
||||||
"config": {"count": {"label": "Count", "color": "#4f46e5"}},
|
|
||||||
}
|
|
||||||
frame = WsStreamBlock(request_id="r1", block_type="chart", data=data)
|
|
||||||
assert frame.type == WsFrameType.stream_block
|
|
||||||
assert frame.block_type == "chart"
|
|
||||||
assert frame.data["chartType"] == "bar"
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_block_entity_ref():
|
|
||||||
frame = WsStreamBlock(
|
|
||||||
request_id="r1",
|
|
||||||
block_type="entity_ref",
|
|
||||||
data={"type": "task", "id": "t-1", "title": "Fix bug"},
|
|
||||||
)
|
|
||||||
assert frame.block_type == "entity_ref"
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_block_table():
|
|
||||||
frame = WsStreamBlock(
|
|
||||||
request_id="r1",
|
|
||||||
block_type="table",
|
|
||||||
data={"headers": ["A", "B"], "rows": [["1", "2"]]},
|
|
||||||
)
|
|
||||||
assert frame.block_type == "table"
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_block_timeline():
|
|
||||||
frame = WsStreamBlock(
|
|
||||||
request_id="r1",
|
|
||||||
block_type="timeline",
|
|
||||||
data={"timelines": [{"id": "c1", "title": "Launch", "date": 1700000000}]},
|
|
||||||
)
|
|
||||||
assert frame.block_type == "timeline"
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_block_invalid_type():
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
WsStreamBlock(
|
|
||||||
request_id="r1",
|
|
||||||
block_type="unknown", # type: ignore[arg-type]
|
|
||||||
data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_block_serializes():
|
|
||||||
frame = WsStreamBlock(request_id="r1", block_type="table", data={"headers": [], "rows": []})
|
|
||||||
d = frame.model_dump()
|
|
||||||
assert d["type"] == "stream_block"
|
|
||||||
assert d["block_type"] == "table"
|
|
||||||
|
|
||||||
|
|
||||||
# ── WsStreamEnd ───────────────────────────────────────────────────────
|
# ── WsStreamEnd ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_defaults():
|
def test_stream_end_defaults():
|
||||||
frame = WsStreamEnd(request_id="r1")
|
frame = WsStreamEnd(request_id="r1")
|
||||||
assert frame.type == WsFrameType.stream_end
|
assert frame.type == WsFrameType.stream_end
|
||||||
assert frame.mutations == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_with_mutations():
|
|
||||||
mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}]
|
|
||||||
frame = WsStreamEnd(request_id="r1", mutations=mutations)
|
|
||||||
assert len(frame.mutations) == 1
|
|
||||||
assert frame.mutations[0]["action"] == "create"
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_serializes():
|
def test_stream_end_serializes():
|
||||||
data = WsStreamEnd(request_id="r2").model_dump()
|
data = WsStreamEnd(request_id="r2").model_dump()
|
||||||
assert data == {"type": "stream_end", "request_id": "r2", "mutations": []}
|
assert data == {"type": "stream_end", "request_id": "r2"}
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_deserializes():
|
def test_stream_end_deserializes():
|
||||||
raw = {"type": "stream_end", "request_id": "r3", "mutations": []}
|
raw = {"type": "stream_end", "request_id": "r3"}
|
||||||
frame = WsStreamEnd.model_validate(raw)
|
frame = WsStreamEnd.model_validate(raw)
|
||||||
assert frame.request_id == "r3"
|
assert frame.request_id == "r3"
|
||||||
|
|
||||||
@@ -265,28 +196,47 @@ def test_stream_end_deserializes():
|
|||||||
|
|
||||||
|
|
||||||
def test_floating_domain_tasks():
|
def test_floating_domain_tasks():
|
||||||
frame = WsFloatingDomain(request_id="r1", domain="tasks")
|
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
|
||||||
assert frame.type == WsFrameType.floating_domain
|
assert frame.type == WsFrameType.floating_domain
|
||||||
assert frame.domain == "tasks"
|
assert frame.domain.type == "task"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("domain", ["tasks", "timelines", "notes", "projects"])
|
def test_floating_domain_valid_domains():
|
||||||
def test_floating_domain_valid_domains(domain: str):
|
frame = WsFloatingDomain(
|
||||||
frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type]
|
request_id="r1",
|
||||||
assert frame.domain == domain
|
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
|
||||||
|
)
|
||||||
|
assert frame.domain.type == "project"
|
||||||
|
assert frame.domain.id == "213213-312321-312312-421321"
|
||||||
|
assert frame.domain.section == "task"
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_invalid():
|
def test_floating_domain_object_valid():
|
||||||
with pytest.raises(ValidationError):
|
frame = WsFloatingDomain(
|
||||||
WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type]
|
request_id="r1",
|
||||||
|
domain=WsDomain(type="project", id="p1", section="task"),
|
||||||
|
)
|
||||||
|
assert frame.domain.type == "project"
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_serializes():
|
def test_floating_domain_serializes():
|
||||||
d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
|
d = WsFloatingDomain(
|
||||||
assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}
|
request_id="r1",
|
||||||
|
domain=WsDomain(type="timeline"),
|
||||||
|
).model_dump()
|
||||||
|
assert d == {
|
||||||
|
"type": "floating_domain",
|
||||||
|
"request_id": "r1",
|
||||||
|
"domain": {"type": "timeline", "id": None, "section": None},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_deserializes():
|
def test_floating_domain_deserializes():
|
||||||
raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
|
raw = {
|
||||||
|
"type": "floating_domain",
|
||||||
|
"request_id": "r1",
|
||||||
|
"domain": {"type": "node", "id": "n-1", "section": None},
|
||||||
|
}
|
||||||
frame = WsFloatingDomain.model_validate(raw)
|
frame = WsFloatingDomain.model_validate(raw)
|
||||||
assert frame.domain == "projects"
|
assert frame.domain.type == "node"
|
||||||
|
assert frame.domain.id == "n-1"
|
||||||
|
|||||||
@@ -45,14 +45,13 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
async def _mock_home_stream(user_id, message, context, reg=None):
|
async def _mock_home_stream(user_id, message, context):
|
||||||
yield "task_agent", ""
|
yield "token", "Hello"
|
||||||
yield "task_agent", '{"type": "text", "content": "Hello"}'
|
|
||||||
|
|
||||||
|
|
||||||
async def _mock_floating_stream(user_id, message, context, reg=None):
|
async def _mock_floating_stream(user_id, message, context):
|
||||||
yield "task_agent", ""
|
yield "floating_domain", {"type": "task", "id": None, "section": None}
|
||||||
yield "task_agent", "Here is a summary"
|
yield "token", "Here is a summary"
|
||||||
|
|
||||||
|
|
||||||
# ── tests ─────────────────────────────────────────────────────────────────────
|
# ── tests ─────────────────────────────────────────────────────────────────────
|
||||||
@@ -61,7 +60,7 @@ def test_home_request_produces_stream_frames(client):
|
|||||||
"""home_request → stream_start, stream_text+, stream_end."""
|
"""home_request → stream_start, stream_text+, stream_end."""
|
||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
|
||||||
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_home_stream):
|
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream):
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
ws.send_text(json.dumps({
|
ws.send_text(json.dumps({
|
||||||
"type": "device_hello", "device_id": "dev-1", "agent_ids": []
|
"type": "device_hello", "device_id": "dev-1", "agent_ids": []
|
||||||
@@ -84,7 +83,7 @@ def test_floating_request_produces_domain_frame(client):
|
|||||||
"""floating_request → floating_domain first, then stream_text*, stream_end."""
|
"""floating_request → floating_domain first, then stream_text*, stream_end."""
|
||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
|
||||||
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_floating_stream):
|
with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
ws.send_text(json.dumps({
|
ws.send_text(json.dumps({
|
||||||
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
|
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
|
||||||
@@ -103,7 +102,7 @@ def test_floating_request_produces_domain_frame(client):
|
|||||||
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
|
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
|
||||||
|
|
||||||
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
|
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
|
||||||
assert domain_frame["domain"] == "tasks"
|
assert domain_frame["domain"]["type"] == "task"
|
||||||
assert domain_frame["request_id"] == "p1"
|
assert domain_frame["request_id"] == "p1"
|
||||||
|
|
||||||
|
|
||||||
@@ -112,11 +111,10 @@ def test_home_request_request_id_propagated(client):
|
|||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
req_id = "my-unique-req-id"
|
req_id = "my-unique-req-id"
|
||||||
|
|
||||||
async def _stream(user_id, message, context, reg=None):
|
async def _stream(user_id, message, context):
|
||||||
yield "note_agent", ""
|
yield "token", "ok"
|
||||||
yield "note_agent", '{"type": "text", "content": "ok"}'
|
|
||||||
|
|
||||||
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_stream):
|
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
ws.send_text(json.dumps({
|
ws.send_text(json.dumps({
|
||||||
"type": "device_hello", "device_id": "dev-3", "agent_ids": []
|
"type": "device_hello", "device_id": "dev-3", "agent_ids": []
|
||||||
|
|||||||
Reference in New Issue
Block a user