Compare commits
66 Commits
feature/de
...
d5fea95561
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5fea95561 | ||
|
|
0b5ef48463 | ||
|
|
ca8721e1ac | ||
|
|
f658e5e6a3 | ||
|
|
341ee140e5 | ||
|
|
741b9b87fb | ||
|
|
2d8abb6311 | ||
|
|
e668e3fd20 | ||
|
|
7ccdad431f | ||
|
|
4073863dc6 | ||
|
|
a85f8fde29 | ||
|
|
90500a3462 | ||
|
|
c1a8ac7669 | ||
|
|
c510cbaae5 | ||
|
|
ce139bbac3 | ||
|
|
3cf067faea | ||
|
|
7253f6fe72 | ||
|
|
41db3a7089 | ||
|
|
cc94194fd1 | ||
|
|
96c91e386d | ||
|
|
c0aef71141 | ||
|
|
467abc8d42 | ||
|
|
5753f8def9 | ||
|
|
e672b58b6f | ||
|
|
d8add7e8cb | ||
|
|
c6c4578f9a | ||
|
|
3aa0b36a6c | ||
|
|
fa231a3642 | ||
|
|
d91c98f86d | ||
|
|
c0619f5c4d | ||
|
|
da282229ff | ||
|
|
7fa6ad5760 | ||
|
|
dcd14220ca | ||
|
|
3cc32569d9 | ||
|
|
bf445ac2ce | ||
|
|
a2d6d689e4 | ||
|
|
aa8bcbf0d8 | ||
|
|
1ce1d492b0 | ||
|
|
552b8eb305 | ||
|
|
0d93b3960d | ||
|
|
f07580574b | ||
|
|
1a8bf11f90 | ||
|
|
e7cdce8287 | ||
|
|
58bc6efd4b | ||
|
|
6c450805cb | ||
|
|
f340d0fa3e | ||
|
|
edc53cb6eb | ||
|
|
725cece5c1 | ||
|
|
297e20ce8d | ||
|
|
5a03bd1cfb | ||
|
|
87b7a1c6c9 | ||
|
|
826f64d6bb | ||
| 5faa6b1d7c | |||
| 02a9684cd6 | |||
| fae9efee0d | |||
| 30b062dd4a | |||
| 2a0331d7ce | |||
| 13fd8677c1 | |||
| 9bd629cb59 | |||
| 9c97702daa | |||
| a1e364c9c0 | |||
| 5b55f1292a | |||
| 5bc9ea6cd6 | |||
| f7404b6f66 | |||
| d667e43c73 | |||
| fe085a7951 |
79
.env.example
79
.env.example
@@ -2,7 +2,7 @@
|
|||||||
ENV=dev
|
ENV=dev
|
||||||
|
|
||||||
# ── Database ──────────────────────────────────────────────────────────────────
|
# ── Database ──────────────────────────────────────────────────────────────────
|
||||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai
|
||||||
|
|
||||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||||
JWT_SECRET=replace-with-a-long-random-secret
|
JWT_SECRET=replace-with-a-long-random-secret
|
||||||
@@ -13,31 +13,76 @@ JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
|||||||
# ── LLM ───────────────────────────────────────────────────────────────────────
|
# ── LLM ───────────────────────────────────────────────────────────────────────
|
||||||
# LiteLLM model identifiers — change to swap providers without code changes.
|
# LiteLLM model identifiers — change to swap providers without code changes.
|
||||||
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
||||||
|
#
|
||||||
|
# API keys — only the key(s) matching your chosen provider(s) are required.
|
||||||
|
# The correct key is picked automatically from the model prefix (e.g.
|
||||||
|
# "anthropic/..." → ANTHROPIC_API_KEY, "gemini/..." → GOOGLE_API_KEY).
|
||||||
OPENAI_API_KEY=
|
OPENAI_API_KEY=
|
||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
GOOGLE_API_KEY=
|
GOOGLE_API_KEY=
|
||||||
LLM_MODEL=gpt-4o
|
CEREBRAS_API_KEY=
|
||||||
LLM_ROUTER_MODEL=gpt-4o-mini
|
|
||||||
|
# Default model used by any agent that does not have a specific override below.
|
||||||
|
LLM_MODEL=gpt-5-mini
|
||||||
|
LLM_EMBED_MODEL=text-embedding-3-small
|
||||||
|
|
||||||
|
# GitHub Copilot — leave empty to use the LiteLLM default token directory.
|
||||||
|
# In Docker, point this to a named-volume path so tokens survive restarts.
|
||||||
|
# GITHUB_COPILOT_TOKEN_DIR=
|
||||||
|
|
||||||
|
# ── Per-agent model overrides ─────────────────────────────────────────────────
|
||||||
|
# Leave a value empty to fall back to LLM_MODEL.
|
||||||
|
# Each agent resolves its API key from the model prefix automatically.
|
||||||
|
#
|
||||||
|
# Intent classifier — routes user messages to the right domain agent.
|
||||||
|
# A small/fast model (e.g. gpt-4o-mini) is usually sufficient here.
|
||||||
|
LLM_MODEL_CLASSIFIER=
|
||||||
|
|
||||||
|
# Home-agent — handles chat from the home screen (all tools available).
|
||||||
|
LLM_MODEL_HOME_AGENT=
|
||||||
|
|
||||||
|
# Floating-agent — handles contextual chat triggered from a task/project/note.
|
||||||
|
LLM_MODEL_FLOATING_AGENT=
|
||||||
|
|
||||||
|
# Unified-processor — processes local directory files (local agent runner).
|
||||||
|
LLM_MODEL_UNIFIED_PROCESSOR=
|
||||||
|
|
||||||
|
# Cloud-processor — fetches and processes data from cloud connectors.
|
||||||
|
LLM_MODEL_CLOUD_PROCESSOR=
|
||||||
|
|
||||||
|
# Brief-agent — produces home and project text briefs.
|
||||||
|
# A small model (e.g. gpt-4o-mini) is sufficient.
|
||||||
|
# LLM_MODEL_BRIEF_AGENT=
|
||||||
|
|
||||||
|
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
||||||
|
LLM_MODEL_SETUP_AGENT=
|
||||||
|
|
||||||
|
# Memory-extractor — Mem0-style extract/decide pipeline (Phase 2).
|
||||||
|
# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0).
|
||||||
|
LLM_MODEL_MEMORY_EXTRACTOR=
|
||||||
|
|
||||||
|
# Memory-miner — proactive pattern mining from episodic history (Phase 5, Power+ only).
|
||||||
|
# Defaults to gpt-4o-mini when empty.
|
||||||
|
LLM_MODEL_MEMORY_MINER=
|
||||||
|
|
||||||
|
# Memory-auditor — weekly contradiction scan + relation label canonicalization (Phase 7).
|
||||||
|
# Defaults to LLM_MODEL when empty (a reasoning-capable model is recommended).
|
||||||
|
LLM_MODEL_MEMORY_AUDITOR=
|
||||||
|
|
||||||
|
# Scheduler — set to false to disable memory cron jobs (automatically false in tests).
|
||||||
|
SCHEDULER_ENABLED=true
|
||||||
|
|
||||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||||
STRIPE_SECRET_KEY=
|
STRIPE_SECRET_KEY=
|
||||||
STRIPE_WEBHOOK_SECRET=
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
|
|
||||||
S3_BUCKET=adiuva
|
|
||||||
S3_REGION=us-east-1
|
|
||||||
S3_ENDPOINT_URL=
|
|
||||||
AWS_ACCESS_KEY_ID=
|
|
||||||
AWS_SECRET_ACCESS_KEY=
|
|
||||||
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
|
|
||||||
|
|
||||||
# ── Vector Store ──────────────────────────────────────────────────────────────
|
# ── Langfuse (leave empty to disable observability) ───────────────────────────
|
||||||
# Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
|
LANGFUSE_SECRET_KEY=
|
||||||
PINECONE_API_KEY=
|
LANGFUSE_PUBLIC_KEY=
|
||||||
PINECONE_INDEX=adiuva
|
# LANGFUSE_BASE_URL=https://cloud.langfuse.com # EU (default)
|
||||||
QDRANT_URL=
|
# LANGFUSE_BASE_URL=https://us.cloud.langfuse.com # US
|
||||||
QDRANT_API_KEY=
|
# LANGFUSE_BASE_URL=http://localhost:3000 # Self-hosted
|
||||||
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
|
|
||||||
|
|
||||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||||
# Comma-separated list parsed by Settings (override default if needed)
|
# Comma-separated list parsed by Settings (override default if needed)
|
||||||
|
|||||||
@@ -48,23 +48,23 @@ jobs:
|
|||||||
key: ${{ secrets.SSH_KEY }}
|
key: ${{ secrets.SSH_KEY }}
|
||||||
script: |
|
script: |
|
||||||
set -e
|
set -e
|
||||||
DEPLOY_DIR="/opt/adiuva-api"
|
DEPLOY_DIR="/opt/adiuvai-api"
|
||||||
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
||||||
TAG="${{ gitea.ref_name }}"
|
TAG="${{ gitea.ref_name }}"
|
||||||
|
|
||||||
# ── Pull latest code ──
|
# ── Pull latest code ──
|
||||||
cd /tmp && rm -rf adiuva-api-deploy
|
cd /tmp && rm -rf adiuvai-api-deploy
|
||||||
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuva-api-deploy
|
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuvai-api-deploy
|
||||||
|
|
||||||
# ── Sync source (preserve .env) ──
|
# ── Sync source (preserve .env) ──
|
||||||
cp -rf /tmp/adiuva-api-deploy/app/ \
|
cp -rf /tmp/adiuvai-api-deploy/app/ \
|
||||||
/tmp/adiuva-api-deploy/alembic/ \
|
/tmp/adiuvai-api-deploy/alembic/ \
|
||||||
/tmp/adiuva-api-deploy/alembic.ini \
|
/tmp/adiuvai-api-deploy/alembic.ini \
|
||||||
/tmp/adiuva-api-deploy/Dockerfile \
|
/tmp/adiuvai-api-deploy/Dockerfile \
|
||||||
/tmp/adiuva-api-deploy/docker-compose.yml \
|
/tmp/adiuvai-api-deploy/docker-compose.yml \
|
||||||
/tmp/adiuva-api-deploy/requirements.txt \
|
/tmp/adiuvai-api-deploy/requirements.txt \
|
||||||
"$DEPLOY_DIR/"
|
"$DEPLOY_DIR/"
|
||||||
rm -rf /tmp/adiuva-api-deploy
|
rm -rf /tmp/adiuvai-api-deploy
|
||||||
|
|
||||||
# ── Verify .env ──
|
# ── Verify .env ──
|
||||||
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
||||||
|
|||||||
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@@ -58,7 +58,7 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Build image
|
- name: Build image
|
||||||
run: docker build -t adiuva-api:ci .
|
run: docker build -t adiuvai-api:ci .
|
||||||
|
|
||||||
- name: Verify gunicorn installed
|
- name: Verify gunicorn installed
|
||||||
run: docker run --rm adiuva-api:ci gunicorn --version
|
run: docker run --rm adiuvai-api:ci gunicorn --version
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -21,12 +21,16 @@ env/
|
|||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
htmlcov/
|
htmlcov/
|
||||||
.coverage
|
.coverage
|
||||||
|
tests/fixtures/private*/
|
||||||
|
|
||||||
# Docker
|
# Docker
|
||||||
*.log
|
*.log
|
||||||
|
|
||||||
# OS
|
# OS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
|
# Smoke scripts (dev-only, not for CI)
|
||||||
|
scripts/smoke_*.py
|
||||||
Thumbs.db
|
Thumbs.db
|
||||||
|
|
||||||
# Claude Code
|
# Claude Code
|
||||||
|
|||||||
@@ -1,523 +0,0 @@
|
|||||||
# AI Refactor Plan — Adiuva Backend
|
|
||||||
|
|
||||||
> **Objective:** Transform backend tools from JSON-action-descriptor-returning functions into real bidirectional executors. Each tool sends structured CRUD operations to the Electron client via WebSocket, receives real data back, and returns meaningful results to the LLM. The LLM reasons about actual user data instead of serialized action payloads.
|
|
||||||
>
|
|
||||||
> **Electron app:** Lives at `../adiuva/`. See `../adiuva/AI_REFACTOR_PLAN.md`.
|
|
||||||
>
|
|
||||||
> **Protocol:** Execute steps sequentially. Each step is atomic and committable. Mark `[x]` when done.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Architecture — Before vs After
|
|
||||||
|
|
||||||
### Before (current)
|
|
||||||
```
|
|
||||||
LLM calls list_tasks(status="todo")
|
|
||||||
→ tool returns: '{"action":"list","table":"tasks","filters":{"status":"todo"}}'
|
|
||||||
→ _tool_loop feeds that JSON string as ToolMessage to LLM
|
|
||||||
→ LLM sees a descriptor, NOT real data — cannot reason about tasks
|
|
||||||
→ Final response: generic "Here are your tasks" (no actual task data)
|
|
||||||
→ Action descriptors sent in final WS frame for Electron to execute post-response
|
|
||||||
```
|
|
||||||
|
|
||||||
### After (target)
|
|
||||||
```
|
|
||||||
LLM calls list_tasks(status="todo")
|
|
||||||
→ tool calls execute_on_client(action="select", table="tasks", filters={status:"todo"})
|
|
||||||
→ WS frame sent to Electron: {type:"tool_call", id:"abc", action:"select", table:"tasks", filters:{status:"todo"}}
|
|
||||||
→ Electron runs: db.select().from(tasks).where(eq(tasks.status, "todo")).all()
|
|
||||||
→ WS frame back: {type:"tool_result", id:"abc", rows:[{id:"1",title:"Buy milk",...}, ...]}
|
|
||||||
→ tool returns: "Found 3 tasks: 1. Buy milk (high, due tomorrow) 2. ..."
|
|
||||||
→ _tool_loop feeds that as ToolMessage to LLM
|
|
||||||
→ LLM sees REAL data — can reason, count, compare, summarize
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## WS Protocol — Typed Frames
|
|
||||||
|
|
||||||
| Direction | `type` | Payload |
|
|
||||||
|---|---|---|
|
|
||||||
| Client → Server | `chat_request` | `{ message: str, context: ChatContext }` |
|
|
||||||
| Server → Client | `text_chunk` | `{ text: str }` |
|
|
||||||
| Server → Client | `tool_call` | `{ id: str, action: str, table?: str, data?: dict, filters?: dict, vector?: list[float], limit?: int }` |
|
|
||||||
| Client → Server | `tool_result` | `{ id: str, row?: dict, rows?: list[dict], results?: list[dict], deleted?: bool, ok?: bool, error?: str }` |
|
|
||||||
| Server → Client | `final` | `{ response: str }` |
|
|
||||||
| Server → Client | `ping` | `{}` |
|
|
||||||
|
|
||||||
**Actions:**
|
|
||||||
|
|
||||||
| `action` | What Electron does (Drizzle) | `tool_result` shape |
|
|
||||||
|---|---|---|
|
|
||||||
| `select` | `db.select().from(table).where(filters)` | `{ rows: [...] }` |
|
|
||||||
| `get` | `db.select().from(table).where(id=...).get()` | `{ row: {...} or null }` |
|
|
||||||
| `insert` | `db.insert(table).values({id: uuid(), ...data}).returning().get()` | `{ row: {...} }` |
|
|
||||||
| `update` | `db.update(table).set(updates).where(id=...).returning().get()` | `{ row: {...} }` |
|
|
||||||
| `delete` | `db.delete(table).where(id=...).run()` | `{ deleted: true }` |
|
|
||||||
| `vector_upsert` | LanceDB upsert with pre-computed vector | `{ ok: true }` |
|
|
||||||
| `vector_search` | LanceDB search by vector | `{ results: [{id, content, score}...] }` |
|
|
||||||
|
|
||||||
**Electron generates IDs + timestamps.** Backend tools never send `id` or `createdAt` in `insert` data — Electron adds `id: uuid()`, `createdAt: Date.now()`, `updatedAt: Date.now()`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## SQLite Schema Reference (Electron's local database)
|
|
||||||
|
|
||||||
Tools must use **camelCase** field names (Drizzle maps them to snake_case internally):
|
|
||||||
|
|
||||||
| Table | Columns |
|
|
||||||
|---|---|
|
|
||||||
| `tasks` | id, projectId, title, description, status (todo\|in_progress\|done), priority (high\|medium\|low), assignee (JSON array string), dueDate (ms), isAiSuggested (0\|1), isApproved (0\|1), createdAt (ms) |
|
|
||||||
| `projects` | id, clientId, name, status (active\|archived), aiSummary, createdAt (ms) |
|
|
||||||
| `timelines` | id, projectId (required), title, date (ms), isAiSuggested (0\|1), isApproved (0\|1), createdAt (ms) |
|
|
||||||
| `notes` | id, projectId, title, content (markdown), createdAt (ms), updatedAt (ms) |
|
|
||||||
| `taskComments` | id, taskId, author, content, createdAt (ms) |
|
|
||||||
| `clients` | id, parentId, name, industry, createdAt (ms) |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase B — Backend Changes
|
|
||||||
|
|
||||||
### Step B.1 — WS context + frame types
|
|
||||||
- [x] Create `app/core/ws_context.py` (~25 lines):
|
|
||||||
- `_client_executor: ContextVar[Callable]` — holds the async callback for the current WS session
|
|
||||||
- `async def execute_on_client(action, table=None, data=None, filters=None, vector=None, limit=None) -> dict`:
|
|
||||||
- Reads callback from ContextVar
|
|
||||||
- Builds `tool_call` payload: `{id: str(uuid4()), action, table, data, filters, vector, limit}` (omits None fields)
|
|
||||||
- Calls `await callback(payload)` — which sends the WS frame and waits for `tool_result`
|
|
||||||
- Returns the result dict
|
|
||||||
- `def set_client_executor(fn)` / `def clear_client_executor()` — ContextVar management
|
|
||||||
- [x] Add to `app/schemas.py`:
|
|
||||||
- `WsFrameType(str, Enum)`: `chat_request`, `text_chunk`, `tool_call`, `tool_result`, `final`, `ping`
|
|
||||||
- `WsToolCall(BaseModel)`: `type`, `id`, `action`, `table?`, `data?`, `filters?`, `vector?`, `limit?`
|
|
||||||
- `WsToolResult(BaseModel)`: `type`, `id`, `row?`, `rows?`, `results?`, `deleted?`, `ok?`, `error?`
|
|
||||||
- `WsTextChunk(BaseModel)`: `type`, `text`
|
|
||||||
- `WsFinal(BaseModel)`: `type`, `response`
|
|
||||||
- **Files:** `app/core/ws_context.py`, `app/schemas.py`
|
|
||||||
- **Outcome:** Any tool can `await execute_on_client(...)` to query/mutate the user's local DB.
|
|
||||||
|
|
||||||
### Step B.2 — Rewrite all 23 tools to use `execute_on_client()`
|
|
||||||
- [x] Each tool: same `@tool` decorator, same parameters, same docstring. Replace `return json.dumps({...})` body with:
|
|
||||||
1. Call `result = await execute_on_client(action=..., table=..., data/filters=...)`
|
|
||||||
2. Return human-readable string with confirmation + key data from `result`
|
|
||||||
|
|
||||||
- [x] **`app/agents/task_agent.py` (8 tools):**
|
|
||||||
- `list_tasks(project_id, status, search, order_by)`:
|
|
||||||
```python
|
|
||||||
result = await execute_on_client(action="select", table="tasks", filters={
|
|
||||||
"projectId": project_id or None,
|
|
||||||
"status": status or None,
|
|
||||||
"search": search or None,
|
|
||||||
"orderBy": order_by or None,
|
|
||||||
})
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return "No tasks found matching the given filters."
|
|
||||||
lines = [f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})" for r in rows]
|
|
||||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
|
||||||
```
|
|
||||||
- `create_task(title, ...)`:
|
|
||||||
```python
|
|
||||||
result = await execute_on_client(action="insert", table="tasks", data={
|
|
||||||
"title": title, "description": description or None, "status": status,
|
|
||||||
"priority": priority, "assignee": assignees, "dueDate": due_date or None,
|
|
||||||
"projectId": project_id or None, "isAiSuggested": is_ai_suggested, "isApproved": is_approved,
|
|
||||||
})
|
|
||||||
row = result["row"]
|
|
||||||
return f"Task created: '{row['title']}' (id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
|
||||||
```
|
|
||||||
- `update_task(task_id, ...)`: build updates dict (same logic as now) → `execute_on_client(action="update", table="tasks", data={"id": task_id, "updates": updates})` → return "Task updated: {title}"
|
|
||||||
- `delete_task(task_id)`: `execute_on_client(action="delete", table="tasks", data={"id": task_id})` → return "Task deleted"
|
|
||||||
- `list_tasks_due_today()`: calculate today's start/end ms → `execute_on_client(action="select", table="tasks", filters={"dueDateFrom": start, "dueDateTo": end})` → format + return
|
|
||||||
- `list_task_comments(task_id)`: `execute_on_client(action="select", table="taskComments", filters={"taskId": task_id})` → format + return
|
|
||||||
- `add_task_comment(task_id, author, content)`: `execute_on_client(action="insert", table="taskComments", data={...})` → return confirmation
|
|
||||||
- `delete_task_comment(comment_id)`: `execute_on_client(action="delete", table="taskComments", data={"id": comment_id})` → return confirmation
|
|
||||||
|
|
||||||
- [x] **`app/agents/project_agent.py` (6 tools):**
|
|
||||||
- `list_projects(client_id, include_archived)`: `execute_on_client(action="select", table="projects", filters={clientId, includeArchived})` → format + return
|
|
||||||
- `list_all_projects()`: `execute_on_client(action="select", table="projects")` → format + return
|
|
||||||
- `get_project(project_id)`: `execute_on_client(action="get", table="projects", data={"id": project_id})` → return project details or "not found"
|
|
||||||
- `create_project(name, client_id)`: `execute_on_client(action="insert", table="projects", data={name, clientId})` → return confirmation + id
|
|
||||||
- `update_project(project_id, ...)`: build updates → `execute_on_client(action="update", ...)` → return confirmation
|
|
||||||
- `delete_project(project_id)`: `execute_on_client(action="delete", ...)` → return confirmation
|
|
||||||
|
|
||||||
- [x] **`app/agents/timeline_agent.py` (4 tools):**
|
|
||||||
- `list_timelines(project_id)`: `execute_on_client(action="select", table="timelines", filters={projectId})` → format + return
|
|
||||||
- `create_timeline(project_id, title, date, ...)`: `execute_on_client(action="insert", table="timelines", data={...})` → return confirmation + id
|
|
||||||
- `update_timeline(timeline_id, ...)`: build updates → `execute_on_client(action="update", ...)` → return confirmation
|
|
||||||
- `delete_timeline(timeline_id)`: `execute_on_client(action="delete", ...)` → return confirmation
|
|
||||||
|
|
||||||
- [x] **`app/agents/note_agent.py` (5 tools):**
|
|
||||||
- `list_notes(project_id)`: `execute_on_client(action="select", table="notes", filters={projectId})` → format + return
|
|
||||||
- `get_note(note_id)`: `execute_on_client(action="get", table="notes", data={"id": note_id})` → return full content or "not found"
|
|
||||||
- `create_note(title, content, project_id)`: `execute_on_client(action="insert", table="notes", data={...})` → then `execute_on_client(action="vector_upsert", data={id, projectId, content}, vector=await embed(content))` → return confirmation
|
|
||||||
- `update_note(note_id, ...)`: build updates → `execute_on_client(action="update", ...)` → then vector_upsert for updated content → return confirmation
|
|
||||||
- `delete_note(note_id)`: `execute_on_client(action="delete", ...)` → return confirmation
|
|
||||||
|
|
||||||
- **Files:** `app/agents/task_agent.py`, `app/agents/project_agent.py`, `app/agents/timeline_agent.py`, `app/agents/note_agent.py`
|
|
||||||
- **Outcome:** All 23 tools query real user data via WS. LLM sees actual rows, not action descriptors.
|
|
||||||
|
|
||||||
### Step B.3 — Bidirectional WebSocket handler
|
|
||||||
- [x] Refactor `app/api/routes/chat.py` WS endpoint:
|
|
||||||
- After auth + accept + receive `chat_request`:
|
|
||||||
1. Create `execute_on_client` callback closure capturing the websocket:
|
|
||||||
```python
|
|
||||||
pending_calls: dict[str, asyncio.Future] = {}
|
|
||||||
|
|
||||||
async def on_client_result(frame: dict):
|
|
||||||
"""Called when a tool_result frame arrives from Electron."""
|
|
||||||
fut = pending_calls.pop(frame["id"], None)
|
|
||||||
if fut and not fut.done():
|
|
||||||
fut.set_result(frame)
|
|
||||||
|
|
||||||
async def execute_callback(payload: dict) -> dict:
|
|
||||||
"""Send tool_call to Electron, wait for tool_result."""
|
|
||||||
call_id = payload["id"]
|
|
||||||
fut = asyncio.get_event_loop().create_future()
|
|
||||||
pending_calls[call_id] = fut
|
|
||||||
await websocket.send_text(json.dumps({"type": "tool_call", **payload}))
|
|
||||||
return await asyncio.wait_for(fut, timeout=30.0)
|
|
||||||
```
|
|
||||||
2. Set `client_executor` ContextVar with `execute_callback`
|
|
||||||
3. Run orchestrator in a task — it calls agents, agents call tools, tools call `execute_on_client()` which goes through the callback
|
|
||||||
4. In parallel, run a message receive loop that dispatches incoming frames:
|
|
||||||
- `tool_result` → `on_client_result(frame)`
|
|
||||||
- `ping` → ignore
|
|
||||||
5. Orchestrator yields `text_chunk` frames → send to client
|
|
||||||
6. Send `final` frame when done
|
|
||||||
7. Clear ContextVar
|
|
||||||
- Keep heartbeat ping every 30s
|
|
||||||
- 30s timeout on `tool_result` — if Electron doesn't respond, future raises `TimeoutError`, tool returns error string to LLM
|
|
||||||
- **Files:** `app/api/routes/chat.py`
|
|
||||||
- **Outcome:** Full bidirectional WS. Tool calls and text streaming happen concurrently on the same connection.
|
|
||||||
|
|
||||||
### Step B.4 — `_tool_loop` — no changes needed
|
|
||||||
- [x] Verify `app/core/agent_registry.py` works unchanged:
|
|
||||||
- `_tool_loop` calls `tool_fn.ainvoke(args)` → tool awaits `execute_on_client()` (WS round-trip) → returns string → `ToolMessage(content=string)` → LLM sees real data
|
|
||||||
- The async WS round-trip happens inside each tool. `_tool_loop` just sees an awaited tool returning a string — same as before, different content.
|
|
||||||
- **No code changes.** Just verify + add a log line for tool execution times if desired.
|
|
||||||
|
|
||||||
### Step B.5 — Orchestrator cleanup
|
|
||||||
- [x] Update `app/core/orchestrator.py`:
|
|
||||||
- `orchestrate_stream()`: remove `"actions": []` from final frame. Final becomes: `{"done": true, "response": "..."}`
|
|
||||||
- No other changes — `classify_intent` → `call_agent` → chunk response → final frame
|
|
||||||
- **Files:** `app/core/orchestrator.py`
|
|
||||||
- **Outcome:** Clean final frame. No more action descriptors in the protocol.
|
|
||||||
|
|
||||||
### Step B.6 — Add `/vectors/embed` endpoint
|
|
||||||
- [x] Add to `app/api/routes/vectors.py`:
|
|
||||||
- `POST /api/v1/storage/vectors/embed`:
|
|
||||||
- Request: `{ text: str }`
|
|
||||||
- Response: `{ vector: list[float] }` (1536-dim from `text-embedding-3-small`)
|
|
||||||
- Auth required (JWT)
|
|
||||||
- Used by:
|
|
||||||
- Backend tools: `note_agent` calls this before `vector_upsert`
|
|
||||||
- Electron: `vectordb.ts` calls this for note embedding on create/update
|
|
||||||
- **Files:** `app/api/routes/vectors.py`
|
|
||||||
- **Outcome:** Single embedding endpoint. Both backend tools and Electron can generate vectors.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Verification
|
|
||||||
|
|
||||||
| What to test | How |
|
|
||||||
|---|---|
|
|
||||||
| **Read flow** | "List my tasks" → `list_tasks` → `tool_call{select, tasks}` → Electron returns rows → LLM describes real tasks |
|
|
||||||
| **Write flow** | "Create a task called Buy milk" → `create_task` → `tool_call{insert, tasks, data:{title:"Buy milk"}}` → Electron inserts + returns row → tool confirms with id |
|
|
||||||
| **Multi-tool** | "How many todo tasks do I have?" → `list_tasks(status=todo)` → LLM counts actual rows → "You have 3 todo tasks" |
|
|
||||||
| **Vector search** | "Find notes about deployment" → tool embeds → `tool_call{vector_search, vector:[...]}` → Electron searches LanceDB → returns matching notes |
|
|
||||||
| **Vector upsert** | "Create a note about..." → insert note → vector_upsert with embedding → both SQLite + LanceDB updated |
|
|
||||||
| **Tool timeout** | Disconnect Electron mid-conversation → 30s timeout → tool returns error → LLM handles gracefully |
|
|
||||||
| **Concurrent calls** | Agent calls 2 tools in sequence → each does WS round-trip → both succeed → LLM sees both results |
|
|
||||||
| **_tool_loop max iter** | Verify 5-iteration limit still works → after 5 tool calls, LLM forced to answer without tools |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Execution Notes
|
|
||||||
|
|
||||||
- **Phase 1 is the critical path.** Auth + backend client + drizzle executor + orchestrator refactor must land first.
|
|
||||||
- **Steps 1.1–1.4 are additive** — existing app keeps working until Step 1.5 swaps the orchestrator.
|
|
||||||
- **Step 2.1 is the point of no return** — after removing LangChain, there's no local AI fallback.
|
|
||||||
- **Phase B (backend changes) must land before Phase 1.3–1.5** — Electron needs the bidirectional WS to talk to.
|
|
||||||
- **Phase 3 and Phase 4 are independent** — can be parallelized after Phase 2.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 3 — Agent System: Config, Orchestration & Cloud Connectors
|
|
||||||
|
|
||||||
> **Objective:** Backend manages all agent configuration, scheduling, orchestration, and cloud data fetching. Two agent types: **Local Directory Agent** (backend triggers Electron to read files, then AI analyzes) and **Cloud Connector Agent** (backend fetches Gmail/Teams data directly, AI analyzes, pushes results to Electron via WS tool_call). All extracted items use existing WS tool infrastructure to insert into Electron's local DB with `is_ai_suggested=True`.
|
|
||||||
>
|
|
||||||
> **Electron Phase 3 plan:** `../adiuva/AI_REFACTOR_PLAN.md` Phase 3 section.
|
|
||||||
>
|
|
||||||
> **Electron UI status (2025):** Steps 3.6, 3.7, 3.8 of the Electron plan are ✅ complete. Agents are configured inside the Settings page (`/settings?section=agents`) — not a standalone route. The `JourneyDialog` (Step 3.8) is embedded inline in the Settings → Agents section. `LocalAgentConfigPanel` and `CloudAgentConfigPanel` (Step 3.7) are also inline. This affects the journey API contract (see Step 3.5 below).
|
|
||||||
|
|
||||||
### Architecture
|
|
||||||
|
|
||||||
```
|
|
||||||
Local Agent:
|
|
||||||
Scheduler/manual trigger ──► check device online ──► WS agent_run → Electron
|
|
||||||
──► Electron reads files ──► WS agent_data → Backend
|
|
||||||
──► Backend AI (prompt_template + file content) ──► WS tool_call(insert) → Electron
|
|
||||||
──► Electron persists with isAiSuggested=1
|
|
||||||
|
|
||||||
Cloud Agent:
|
|
||||||
Scheduler/manual trigger ──► Backend fetches Gmail/Teams (OAuth) ──► Backend AI analyzes
|
|
||||||
──► check device online ──► WS tool_call(insert) → Electron ──► Electron persists
|
|
||||||
```
|
|
||||||
|
|
||||||
**New WS frame types:**
|
|
||||||
|
|
||||||
| Direction | `type` | Payload |
|
|
||||||
|---|---|---|
|
|
||||||
| Server → Client | `agent_run` | `{ run_id, agent_id, config: { paths, file_extensions, prompt_template, data_types } }` |
|
|
||||||
| Client → Server | `agent_data` | `{ run_id, files: [{ path, name, content, metadata }] }` |
|
|
||||||
| Client → Server | `agent_complete` | `{ run_id, files_read, errors }` |
|
|
||||||
| Client → Server | `device_hello` | `{ device_id, agent_ids }` |
|
|
||||||
|
|
||||||
### Step 3.1 — Agent config tables
|
|
||||||
- [x] Add to `app/models.py`:
|
|
||||||
- **`LocalAgentConfig`**:
|
|
||||||
- `id` UUID PK
|
|
||||||
- `user_id` FK → users
|
|
||||||
- `device_id` str — identifies which Electron install this config belongs to
|
|
||||||
- `name` str
|
|
||||||
- `directory_paths` JSON — list of absolute paths on the device
|
|
||||||
- `data_types` JSON — which tables to extract to: `["tasks", "notes", "timelines", "projects"]`
|
|
||||||
- `prompt_template` text — user-configured via Chatbot Journey
|
|
||||||
- `file_extensions` JSON — e.g. `[".eml", ".txt", ".pdf", ".md"]`
|
|
||||||
- `schedule_cron` str — e.g. `"0 */6 * * *"` (every 6h)
|
|
||||||
- `enabled` bool (default True)
|
|
||||||
- `last_run_at` datetime nullable
|
|
||||||
- `created_at`, `updated_at` timestamps
|
|
||||||
- **`CloudAgentConfig`**:
|
|
||||||
- `id` UUID PK
|
|
||||||
- `user_id` FK → users
|
|
||||||
- `provider` str — enum: `gmail`, `teams`, `outlook`
|
|
||||||
- `name` str
|
|
||||||
- `data_types` JSON — same format as local
|
|
||||||
- `prompt_template` text
|
|
||||||
- `oauth_token_encrypted` text — Fernet-encrypted OAuth2 credentials
|
|
||||||
- `schedule_cron` str
|
|
||||||
- `enabled` bool (default True)
|
|
||||||
- `last_run_at` datetime nullable
|
|
||||||
- `filter_config` JSON — provider-specific: `{ labels: [], date_range: {from, to}, senders: [] }`
|
|
||||||
- `created_at`, `updated_at` timestamps
|
|
||||||
- **`AgentRunLog`**:
|
|
||||||
- `id` UUID PK
|
|
||||||
- `agent_id` str — references LocalAgentConfig.id or CloudAgentConfig.id
|
|
||||||
- `agent_type` str — `local` or `cloud`
|
|
||||||
- `user_id` FK → users
|
|
||||||
- `status` str — `running`, `success`, `error`, `partial`
|
|
||||||
- `items_processed` int (default 0)
|
|
||||||
- `items_created` int (default 0)
|
|
||||||
- `errors` JSON — list of error strings
|
|
||||||
- `started_at` datetime
|
|
||||||
- `completed_at` datetime nullable
|
|
||||||
- [x] Add Pydantic schemas to `app/schemas.py`:
|
|
||||||
- `LocalAgentConfigCreate`, `LocalAgentConfigUpdate`, `LocalAgentConfigResponse`
|
|
||||||
- `CloudAgentConfigCreate`, `CloudAgentConfigUpdate`, `CloudAgentConfigResponse`
|
|
||||||
- `AgentRunLogResponse`
|
|
||||||
- `AgentCatalogItem` — `{ type, name, description, config_schema }`
|
|
||||||
- `WsAgentRun`, `WsAgentData`, `WsAgentComplete`, `WsDeviceHello`
|
|
||||||
- [x] Generate Alembic migration
|
|
||||||
- **Files:** `app/models.py`, `app/schemas.py`, `alembic/versions/`
|
|
||||||
- **Outcome:** Agent config and run tracking tables in PostgreSQL.
|
|
||||||
|
|
||||||
### Step 3.2 — Agent CRUD API routes
|
|
||||||
- [x] Create `app/api/routes/agents.py`:
|
|
||||||
- `GET /api/v1/agents/catalog` — returns hardcoded agent type catalog:
|
|
||||||
- `local_directory`: "Watches local directories, extracts data from files using AI"
|
|
||||||
- `gmail`: "Scans Gmail inbox, extracts tasks/notes from emails"
|
|
||||||
- `teams`: "Monitors Teams messages, extracts action items"
|
|
||||||
- `outlook`: "Scans Outlook inbox, extracts tasks/notes"
|
|
||||||
- `GET /api/v1/agents/local` — list user's local agent configs
|
|
||||||
- `POST /api/v1/agents/local` — create local agent config
|
|
||||||
- Body: `{ name, device_id, directory_paths, data_types, prompt_template, file_extensions, schedule_cron }`
|
|
||||||
- Tier check: count enabled agents ≤ `batch_active` limit
|
|
||||||
- `PUT /api/v1/agents/local/{id}` — update config (ownership check)
|
|
||||||
- `DELETE /api/v1/agents/local/{id}` — delete config + associated run logs
|
|
||||||
- `GET /api/v1/agents/cloud` — list user's cloud agent configs
|
|
||||||
- `POST /api/v1/agents/cloud` — create cloud connector config
|
|
||||||
- Body: `{ provider, name, data_types, prompt_template, oauth_token_encrypted, schedule_cron, filter_config }`
|
|
||||||
- Tier check: same `batch_active` limit (local + cloud count together)
|
|
||||||
- `PUT /api/v1/agents/cloud/{id}` — update config
|
|
||||||
- `DELETE /api/v1/agents/cloud/{id}` — delete config + run logs
|
|
||||||
- `GET /api/v1/agents/runs` — query params: `agent_id`, `page`, `limit` → paginated run logs
|
|
||||||
- `POST /api/v1/agents/{id}/run` — manual trigger (dispatches to agent runner)
|
|
||||||
- All routes require JWT auth; ownership enforced on all mutations
|
|
||||||
- [x] Register router in `app/main.py`
|
|
||||||
- **Files:** `app/api/routes/agents.py`, `app/main.py`
|
|
||||||
- **Outcome:** Full CRUD for agent configs with tier-gated creation limits.
|
|
||||||
|
|
||||||
### Step 3.3 — Device WS endpoint
|
|
||||||
- [x] Create `app/api/routes/device_ws.py`:
|
|
||||||
- `WebSocket /api/v1/ws/device?token=<jwt>` — persistent connection from Electron
|
|
||||||
- On connect:
|
|
||||||
- Authenticate JWT
|
|
||||||
- Receive `device_hello` frame → extract `device_id`, `agent_ids`
|
|
||||||
- Store connection in `DeviceConnectionManager` (in-memory dict: `user_id → { ws, device_id }`)
|
|
||||||
- Check for overdue agent runs → trigger them immediately
|
|
||||||
- Message loop:
|
|
||||||
- `agent_data` → route to active agent run handler
|
|
||||||
- `agent_complete` → finalize agent run
|
|
||||||
- `tool_result` → route to pending tool call (same pattern as chat WS)
|
|
||||||
- `pong` → heartbeat ack
|
|
||||||
- On disconnect:
|
|
||||||
- Remove from `DeviceConnectionManager`
|
|
||||||
- Mark any in-progress agent runs as `error` with "device disconnected"
|
|
||||||
- Heartbeat: send `ping` every 30s, disconnect if no `pong` within 10s
|
|
||||||
- [x] Create `app/core/device_manager.py`:
|
|
||||||
- `DeviceConnectionManager` (singleton):
|
|
||||||
- `register(user_id, device_id, ws)` — stores active connection
|
|
||||||
- `unregister(user_id)` — removes connection
|
|
||||||
- `get_ws(user_id) -> WebSocket | None` — returns active WS if device is online
|
|
||||||
- `is_online(user_id, device_id=None) -> bool` — optionally checks specific device
|
|
||||||
- `send_frame(user_id, frame: dict)` — sends JSON frame to device
|
|
||||||
- **Files:** `app/api/routes/device_ws.py`, `app/core/device_manager.py`, `app/main.py`
|
|
||||||
- **Outcome:** Backend maintains persistent WS connections to Electron devices for agent triggers.
|
|
||||||
|
|
||||||
### Step 3.4 — Agent run orchestrator
|
|
||||||
- [x] Create `app/core/agent_runner.py`:
|
|
||||||
- `async run_local_agent(user_id, config: LocalAgentConfig, device_mgr: DeviceConnectionManager)`:
|
|
||||||
1. Check device is online with matching `device_id` → abort if offline
|
|
||||||
2. Create `AgentRunLog` with `status=running`
|
|
||||||
3. Send `WsAgentRun` frame to Electron with config (paths, extensions, prompt)
|
|
||||||
4. Await `WsAgentData` frames — collect file contents
|
|
||||||
5. Await `WsAgentComplete` frame — Electron signals done reading
|
|
||||||
6. For each file: call LLM with `prompt_template` + file content → extract structured items
|
|
||||||
7. For each extracted item: send `WsToolCall(insert, table, data)` to Electron → await `WsToolResult`
|
|
||||||
- All inserts include `is_ai_suggested=True, is_approved=False`
|
|
||||||
8. Update `AgentRunLog`: `status=success`, `items_processed`, `items_created`
|
|
||||||
- `async run_cloud_agent(user_id, config: CloudAgentConfig, device_mgr: DeviceConnectionManager)`:
|
|
||||||
1. Check device is online → abort if offline (results must push to Electron)
|
|
||||||
2. Create `AgentRunLog` with `status=running`
|
|
||||||
3. Decrypt OAuth credentials from `config.oauth_token_encrypted`
|
|
||||||
4. Fetch data from cloud provider (Step 3.6):
|
|
||||||
- Gmail: `google-api-python-client` + `filter_config` label/date filters
|
|
||||||
- Teams: `msgraph-sdk` + channel/date filters
|
|
||||||
- Outlook: `msgraph-sdk` + folder/date filters
|
|
||||||
5. For each item: call LLM with `prompt_template` + email/message content → extract structured items
|
|
||||||
6. For each extracted item: send `WsToolCall(insert)` to Electron → await `WsToolResult`
|
|
||||||
7. Update `AgentRunLog`
|
|
||||||
- `async trigger_pending_runs(user_id, device_id, device_mgr)`:
|
|
||||||
- Called when Electron connects (after `device_hello`)
|
|
||||||
- Queries all enabled agent configs where `last_run_at + schedule_interval < now()`
|
|
||||||
- For local agents: only triggers if `config.device_id == device_id`
|
|
||||||
- For cloud agents: triggers regardless of device (any connected device can receive results)
|
|
||||||
- Executes runs sequentially (one at a time to avoid overwhelming the WS)
|
|
||||||
- Error handling: on any failure, update `AgentRunLog` with `status=error` + error details
|
|
||||||
- [x] Wire `POST /agents/{id}/run` endpoint to dispatch background task via `asyncio.create_task()`
|
|
||||||
- [x] Replace `_trigger_pending_runs_stub` in `device_ws.py` with real `trigger_pending_runs` call
|
|
||||||
- [x] Add `croniter>=3.0.0` to `requirements.txt`
|
|
||||||
- [x] 23 unit + integration tests covering all code paths
|
|
||||||
- **Files:** `app/core/agent_runner.py`, `app/api/routes/agents.py`, `app/api/routes/device_ws.py`, `requirements.txt`, `tests/test_agent_runner.py`
|
|
||||||
- **Outcome:** Backend drives all agent execution — both local (via WS file request) and cloud (direct API calls — stub until Step 3.6).
|
|
||||||
|
|
||||||
### Step 3.5 — Chatbot Journey endpoint
|
|
||||||
- [x] Create `app/api/routes/agent_setup.py`:
|
|
||||||
- `POST /api/v1/agents/journey/start`:
|
|
||||||
- Body: `{ agent_type: "local"|"cloud", agent_id: str | None }`
|
|
||||||
- `agent_type`: which kind of agent this journey configures.
|
|
||||||
- `agent_id`: optional — if provided, the session is pre-seeded with the existing agent's `prompt_template` so the user can refine it. If absent, fresh journey.
|
|
||||||
- **No `data_types` field** — data types are determined through the conversation itself, not sent upfront.
|
|
||||||
- Creates a journey session (in-memory or Redis-backed)
|
|
||||||
- Returns first AI message: contextual question based on agent type
|
|
||||||
- Local: "What kind of files are in the directories you want to monitor? (emails, documents, logs, etc.)"
|
|
||||||
- Cloud: "What kind of emails/messages should I look for? (client communications, invoices, meeting notes, etc.)"
|
|
||||||
- Response: `{ session_id, message, done: false }`
|
|
||||||
- **Electron note:** `proxyPost` auto-converts camelCase keys to snake_case. Electron sends `{ agentType, agentId }` → backend receives `{ agent_type, agent_id }`.
|
|
||||||
- `POST /api/v1/agents/journey/message`:
|
|
||||||
- Body: `{ session_id, message }`
|
|
||||||
- AI processes user's answer, asks follow-up questions (max 5 turns)
|
|
||||||
- System prompt: "You are configuring a data extraction agent for a freelancer. Ask about file format, what data to extract (tasks, notes, timelines), naming conventions, priority rules, and any special mapping. After 3-5 questions, generate a detailed prompt_template."
|
|
||||||
- When AI determines enough context: `{ session_id, message: "Here's your configuration...", done: true, prompt_template: "..." }`
|
|
||||||
- The `prompt_template` is a structured instruction for the extraction LLM (e.g. "Extract tasks from email. Subject becomes task title. If body contains 'urgent' or 'ASAP', set priority to 'high'. Extract due dates if mentioned.")
|
|
||||||
- **Electron note:** `toCamelCase` converts the response → Electron reads `promptTemplate` from the final message and auto-fills the agent config panel. User clicks "Save & apply" which calls `agent.local.update` / `agent.cloud.update` tRPC mutation.
|
|
||||||
- **Files:** `app/api/routes/agent_setup.py`, `app/main.py`
|
|
||||||
- **Outcome:** Users configure AI prompts through guided conversation. Journey can refine an existing config when `agent_id` is provided. ✅
|
|
||||||
|
|
||||||
### Step 3.6 — Cloud provider integrations
|
|
||||||
- [x] Create `app/integrations/gmail.py`:
|
|
||||||
- `GmailClient`:
|
|
||||||
- `__init__(oauth_token)` — initializes Google API client
|
|
||||||
- `async fetch_messages(filter_config, since: datetime) -> list[EmailMessage]`
|
|
||||||
- `EmailMessage`: `{ id, subject, sender, body_text, date, labels }`
|
|
||||||
- Handles token refresh via Google OAuth2 refresh flow
|
|
||||||
- Respects `filter_config.labels`, `filter_config.date_range`, `filter_config.senders`
|
|
||||||
- [x] Create `app/integrations/ms_graph.py`:
|
|
||||||
- `MSGraphClient`:
|
|
||||||
- `__init__(oauth_token)` — initializes MS Graph client
|
|
||||||
- `async fetch_emails(filter_config, since: datetime) -> list[EmailMessage]` (Outlook)
|
|
||||||
- `async fetch_messages(filter_config, since: datetime) -> list[ChatMessage]` (Teams)
|
|
||||||
- `ChatMessage`: `{ id, content, sender, channel, date }`
|
|
||||||
- Handles token refresh via MSAL
|
|
||||||
- [x] Create `app/integrations/__init__.py` — factory: `get_provider(provider_name) -> GmailClient | MSGraphClient`
|
|
||||||
- **Dependencies:** `google-api-python-client`, `google-auth-oauthlib`, `msgraph-sdk`, `msal`
|
|
||||||
- **Files:** `app/integrations/gmail.py`, `app/integrations/ms_graph.py`, `app/integrations/__init__.py`
|
|
||||||
- **Outcome:** Backend can fetch emails/messages from Gmail, Outlook, and Teams.
|
|
||||||
|
|
||||||
### Step 3.7 — Agent scheduler
|
|
||||||
- [ ] Create `app/core/agent_scheduler.py`:
|
|
||||||
- Uses `APScheduler` (or simple asyncio loop) to check agent schedules
|
|
||||||
- Every 60s: query enabled agents where `last_run_at + cron_interval < now()`
|
|
||||||
- For each due agent:
|
|
||||||
- Check if user's device is online via `DeviceConnectionManager`
|
|
||||||
- If online: dispatch to `agent_runner`
|
|
||||||
- If offline: skip (will trigger on next `device_hello`)
|
|
||||||
- Locks: use PostgreSQL advisory locks to prevent duplicate runs in multi-instance deployments
|
|
||||||
- [ ] Integrate with FastAPI lifespan (start scheduler on app startup, shutdown gracefully)
|
|
||||||
- **Dependencies:** `apscheduler>=4.0`
|
|
||||||
- **Files:** `app/core/agent_scheduler.py`, `app/main.py`
|
|
||||||
- **Outcome:** Agents run automatically on their configured schedules.
|
|
||||||
|
|
||||||
### Step 3.8 — OAuth flow endpoints
|
|
||||||
- [ ] Create `app/api/routes/oauth.py`:
|
|
||||||
- `GET /api/v1/oauth/{provider}/authorize` — returns OAuth authorization URL
|
|
||||||
- Gmail: Google OAuth2 with `gmail.readonly` scope
|
|
||||||
- Outlook/Teams: MS identity platform with `Mail.Read`, `ChannelMessage.Read.All` scopes
|
|
||||||
- `GET /api/v1/oauth/{provider}/callback` — handles OAuth redirect
|
|
||||||
- Exchanges auth code for access + refresh tokens
|
|
||||||
- Encrypts tokens with Fernet (server-side key from settings)
|
|
||||||
- Returns encrypted token blob for storage in `CloudAgentConfig.oauth_token_encrypted`
|
|
||||||
- `POST /api/v1/oauth/{provider}/refresh` — refresh expired OAuth token
|
|
||||||
- **Files:** `app/api/routes/oauth.py`, `app/main.py`
|
|
||||||
- **Outcome:** Users can connect Gmail/Teams/Outlook accounts securely.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Phase 3 — Verification
|
|
||||||
|
|
||||||
| # | Scenario | Expected |
|
|
||||||
|---|---|---|
|
|
||||||
| 1 | **Agent CRUD** | Create/read/update/delete local and cloud configs; tier limits enforced (free=2, pro=10) |
|
|
||||||
| 2 | **WS device connect** | Electron connects → `device_hello` → backend stores connection → triggers overdue runs |
|
|
||||||
| 3 | **Local agent run** | Backend sends `agent_run` → Electron reads files → `agent_data` → backend AI extracts → `tool_call(insert)` → Electron persists with `isAiSuggested=1` |
|
|
||||||
| 4 | **Cloud agent run** | Backend fetches Gmail → AI extracts tasks → `tool_call(insert)` → Electron persists |
|
|
||||||
| 5 | **Device binding** | Local agent config with `device_id=A` only triggers when device A is connected |
|
|
||||||
| 6 | **Chatbot Journey** | Start journey → 3-5 Q&A turns → produces valid `prompt_template` |
|
|
||||||
| 7 | **Schedule** | Agent with `schedule_cron="0 */6 * * *"` runs every 6h when device is online |
|
|
||||||
| 8 | **Offline resilience** | Device offline → runs skipped → device reconnects → overdue runs trigger immediately |
|
|
||||||
| 9 | **OAuth flow** | Gmail authorize → callback → token encrypted → stored in config → fetch emails works |
|
|
||||||
|
|
||||||
### Phase 3 — New Dependencies
|
|
||||||
|
|
||||||
| Package | Purpose |
|
|
||||||
|---|---|
|
|
||||||
| `google-api-python-client` | Gmail API access |
|
|
||||||
| `google-auth-oauthlib` | Gmail OAuth2 flow |
|
|
||||||
| `msgraph-sdk` | Outlook + Teams API access |
|
|
||||||
| `msal` | MS identity platform auth |
|
|
||||||
| `apscheduler>=4.0` | Agent scheduling |
|
|
||||||
| `cryptography` (Fernet) | OAuth token encryption at rest |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ~~Phase 5 — Shared Memory~~ (SUPERSEDED)
|
|
||||||
|
|
||||||
> **This phase has been fully replaced by `V3_MIGRATION_PLAN.md`.**
|
|
||||||
>
|
|
||||||
> - Chat WS fix → V3 Step 5 (Unified WS Handler — single multiplexed socket)
|
|
||||||
> - Agent memory → V3 Steps 6–7 (Cloud-side MemGPT-style memory in PostgreSQL + pgvector, encrypted at rest with per-user Fernet key)
|
|
||||||
>
|
|
||||||
> The on-device KV approach (Electron SQLite `agent_memory` table) is no longer the target architecture.
|
|
||||||
> See `V3_MIGRATION_PLAN.md` for the current plan.
|
|
||||||
572
BACKEND_PLAN.md
572
BACKEND_PLAN.md
@@ -1,572 +0,0 @@
|
|||||||
# Backend Plan — Adiuva Cloud API
|
|
||||||
|
|
||||||
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
|
|
||||||
>
|
|
||||||
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, E2E backup blob storage, cloud storage (encrypted blobs), cloud vector store, and plugin marketplace.
|
|
||||||
> The backend NEVER persists user data in plaintext. Cloud storage blobs are E2E encrypted before upload — the backend only verifies integrity, never decrypts.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── app/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── main.py # FastAPI entry + CORS + lifespan + router includes
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── agent_registry.py # Base classes + singleton registry
|
|
||||||
│ │ ├── orchestrator.py # LLM-based intent router
|
|
||||||
│ │ ├── execution_plan.py # Plan builder + cache
|
|
||||||
│ │ └── plugin_loader.py # Dynamic agent loading
|
|
||||||
│ ├── agents/ # Chat agents (proprietary logic + prompts)
|
|
||||||
│ │ ├── __init__.py # Auto-registers all agents
|
|
||||||
│ │ ├── task_agent.py
|
|
||||||
│ │ ├── calendar_agent.py
|
|
||||||
│ │ ├── email_agent.py
|
|
||||||
│ │ └── analytics_agent.py
|
|
||||||
│ ├── api/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── routes/
|
|
||||||
│ │ │ ├── __init__.py
|
|
||||||
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
|
|
||||||
│ │ │ ├── plans.py # GET /plans/playbook
|
|
||||||
│ │ │ ├── storage.py # CRUD cloud storage (E2E encrypted blobs)
|
|
||||||
│ │ │ ├── vectors.py # Upsert/search cloud vector store
|
|
||||||
│ │ │ ├── backup.py # PUT/GET /backup
|
|
||||||
│ │ │ ├── plugins.py # Plugin marketplace
|
|
||||||
│ │ │ ├── auth.py # Register/login/refresh
|
|
||||||
│ │ │ └── billing.py # Checkout/webhook/subscription
|
|
||||||
│ │ └── middleware/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── auth.py # JWT validation
|
|
||||||
│ │ ├── rate_limit.py # Tier-aware rate limiting
|
|
||||||
│ │ └── sanitizer.py # Strip prompt metadata from responses
|
|
||||||
│ ├── storage/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── blob_store.py # S3 for E2E encrypted blobs
|
|
||||||
│ │ ├── vector_store.py # Cloud vector store (Pinecone/Qdrant)
|
|
||||||
│ │ └── encryption.py # Integrity verification only — NO decryption
|
|
||||||
│ ├── marketplace/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── plugin_registry.py # Plugin catalog (metadata, versions, ratings)
|
|
||||||
│ │ ├── plugin_review.py # Review queue + approval workflow
|
|
||||||
│ │ └── revenue_share.py # 70/30 split tracking with Stripe Connect
|
|
||||||
│ ├── billing/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── stripe_service.py # Stripe checkout + webhooks
|
|
||||||
│ │ └── tier_manager.py # Feature matrix per tier
|
|
||||||
│ └── config/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ └── settings.py # Pydantic BaseSettings (env-based)
|
|
||||||
├── tests/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── conftest.py # Fixtures: test client, mock agents, mock LLM
|
|
||||||
│ ├── test_orchestrator.py
|
|
||||||
│ ├── test_agents.py
|
|
||||||
│ ├── test_auth.py
|
|
||||||
│ ├── test_backup.py
|
|
||||||
│ ├── test_storage.py
|
|
||||||
│ └── test_plugins.py
|
|
||||||
├── alembic/ # DB migrations (auth/billing/marketplace tables only)
|
|
||||||
│ ├── alembic.ini
|
|
||||||
│ └── versions/
|
|
||||||
├── requirements.txt
|
|
||||||
├── Dockerfile
|
|
||||||
├── docker-compose.yml # App + PostgreSQL + Redis (dev)
|
|
||||||
├── .env.example
|
|
||||||
└── README.md
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step-by-Step Implementation
|
|
||||||
|
|
||||||
### Step 1 — Project scaffolding ✅
|
|
||||||
- [x] Initialize repo with the directory structure above
|
|
||||||
- [x] Write `requirements.txt`:
|
|
||||||
```
|
|
||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
langchain>=0.3.0
|
|
||||||
langchain-openai>=0.3.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
python-jose[cryptography]>=3.3.0
|
|
||||||
stripe>=11.0.0
|
|
||||||
boto3>=1.35.0
|
|
||||||
slowapi>=0.1.9
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
asyncpg>=0.30.0
|
|
||||||
alembic>=1.14.0
|
|
||||||
bcrypt>=4.2.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
httpx>=0.28.0
|
|
||||||
websockets>=14.0
|
|
||||||
pytest>=8.0.0
|
|
||||||
pytest-asyncio>=0.24.0
|
|
||||||
```
|
|
||||||
- [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
|
|
||||||
- [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod), `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
|
||||||
- [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
|
|
||||||
- [x] Write `docker-compose.yml`: app, postgres:16, optional redis
|
|
||||||
- [x] Write `.env.example`
|
|
||||||
- **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes).
|
|
||||||
|
|
||||||
### Step 2 — Pydantic schemas (API contracts) ✅
|
|
||||||
- [x] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
|
|
||||||
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
|
|
||||||
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
|
|
||||||
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
|
|
||||||
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification', 'call_agent']`, `table: str | None`, `data: dict | None`, `agent: str | None`
|
|
||||||
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
|
|
||||||
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
|
|
||||||
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
|
|
||||||
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
|
|
||||||
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
|
|
||||||
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
|
||||||
- `StorageRecord`: `id: str`, `user_id: str`, `table: str`, `blob: bytes`, `checksum: str`, `created_at: int`, `updated_at: int` — blob is always E2E encrypted by client
|
|
||||||
- `StorageRecordCreate`: `table: str`, `blob: bytes`, `checksum: str`
|
|
||||||
- `StorageRecordUpdate`: `blob: bytes`, `checksum: str`
|
|
||||||
- `VectorUpsertRequest`: `vectors: list[VectorItem]`
|
|
||||||
- `VectorItem`: `id: str`, `blob: bytes`, `checksum: str` — vector + metadata encrypted by client
|
|
||||||
- `VectorSearchRequest`: `query_blob: bytes`, `top_k: int = 10`
|
|
||||||
- `VectorSearchResponse`: `results: list[VectorSearchResult]`
|
|
||||||
- `VectorSearchResult`: `id: str`, `score: float`, `blob: bytes`
|
|
||||||
- `PluginManifest`: `id: str`, `name: str`, `description: str`, `version: str`, `author: str`, `permissions: list[str]`, `category: str`, `price_cents: int = 0`
|
|
||||||
- `PluginListResponse`: `plugins: list[PluginManifest]`, `total: int`, `page: int`
|
|
||||||
- `PluginInstallRequest`: `plugin_id: str`
|
|
||||||
- **Outcome:** All request/response models defined and validated.
|
|
||||||
|
|
||||||
### Step 3 — Agent Registry + base classes ✅
|
|
||||||
- [x] `app/core/agent_registry.py`:
|
|
||||||
- `BaseAgent(ABC)`:
|
|
||||||
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
|
|
||||||
- Abstract `get_name() -> str`, `get_description() -> str`
|
|
||||||
- `ChatAgent(BaseAgent)`:
|
|
||||||
- Abstract `async handle(query: str, context: dict) -> str`
|
|
||||||
- Abstract `get_tools() -> list` (LangChain tool definitions)
|
|
||||||
- Concrete `_tool_loop(llm, messages, tools, max_iter=5) -> str` — shared tool-calling loop
|
|
||||||
- `AgentRegistry` (singleton):
|
|
||||||
- `_agents: dict[str, ChatAgent]`
|
|
||||||
- `register(agent_class)` — decorator pattern
|
|
||||||
- `get(name) -> ChatAgent`
|
|
||||||
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
|
|
||||||
- `async call_agent(name, query, context) -> str` — for inter-agent calls
|
|
||||||
- [x] Unit tests: register, get, list, call_agent with mock
|
|
||||||
- **Outcome:** Pluggable agent framework.
|
|
||||||
|
|
||||||
### Step 4 — Orchestrator ✅
|
|
||||||
- [x] `app/core/orchestrator.py`:
|
|
||||||
- `async classify_intent(message, context, registry) -> str`:
|
|
||||||
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
|
|
||||||
- Uses gpt-4o-mini via LangChain for low latency
|
|
||||||
- Falls back to `task_agent` if no clear match
|
|
||||||
- `async route_single(agent_name, message, context) -> ChatResponse`:
|
|
||||||
- Instantiates agent from registry
|
|
||||||
- Calls `agent.handle(message, context)`
|
|
||||||
- Returns response + any actions the agent produced
|
|
||||||
- `async route_pipeline(agent_names, message, context) -> ChatResponse`:
|
|
||||||
- Executes agents in sequence
|
|
||||||
- Each agent receives `{...context, previous_results: [...]}`
|
|
||||||
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
|
|
||||||
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
|
|
||||||
- Main entry point
|
|
||||||
- Context is transparent to orchestrator — data may originate from local or cloud storage on the client side
|
|
||||||
- Classifies intent
|
|
||||||
- If `execution_mode == 'direct'`: route + return response
|
|
||||||
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
|
|
||||||
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
|
|
||||||
- Same as orchestrate but yields tokens for WebSocket streaming
|
|
||||||
- [x] Integration tests with mocked LLM and mocked agents
|
|
||||||
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
|
|
||||||
|
|
||||||
### Step 5 — Execution Plan generator ✅
|
|
||||||
- [x] `app/core/execution_plan.py`:
|
|
||||||
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
|
|
||||||
- `ExecutionPlanBuilder`:
|
|
||||||
- `add_step(action, params) -> self`
|
|
||||||
- `add_llm_step(template_id, variables) -> self`
|
|
||||||
- `add_data_step(action, data_from_step) -> self`
|
|
||||||
- `build() -> ExecutionPlan` — validates step references
|
|
||||||
- `PlanCache`:
|
|
||||||
- In-memory LRU (maxsize=1000)
|
|
||||||
- `cache_plan(key, plan)`, `get_plan(key)`, `get_all_playbooks() -> list[ExecutionPlan]`
|
|
||||||
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
|
|
||||||
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
|
|
||||||
|
|
||||||
### Step 6 — Chat Agents ✅
|
|
||||||
- [x] `app/agents/task_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
- Tools (8): `list_tasks(project_id, status, search, order_by)`, `create_task(title, description, status, priority, assignees, due_date, project_id, is_ai_suggested, is_approved)`, `update_task(task_id, ...)`, `delete_task(task_id)`, `list_tasks_due_today()`, `list_task_comments(task_id)`, `add_task_comment(task_id, author, content)`, `delete_task_comment(comment_id)`
|
|
||||||
- status: `todo|in_progress|done`; priority: `high|medium|low`; assignees: JSON-encoded string; due_date: ms timestamp
|
|
||||||
- Accepts flexible context; sentinel `-1` for optional integer update fields
|
|
||||||
- [x] `app/agents/timeline_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages project timelines (milestones): list, create, update, delete"
|
|
||||||
- Tools (4): `list_timelines(project_id)`, `create_timeline(project_id, title, date, is_ai_suggested, is_approved)`, `update_timeline(timeline_id, ...)`, `delete_timeline(timeline_id)`
|
|
||||||
- `project_id` is required for create; date is a ms timestamp; supports AI-suggestion + approval workflow
|
|
||||||
- [x] `app/agents/project_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
- Tools (6): `list_projects(client_id, include_archived)`, `list_all_projects()`, `get_project(project_id)`, `create_project(name, client_id)`, `update_project(project_id, ...)`, `delete_project(project_id)`
|
|
||||||
- status: `active|archived`; prefers archive over deletion (docstring guard on delete)
|
|
||||||
- [x] `app/agents/note_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages notes: list, get, create, update, delete"
|
|
||||||
- Tools (5): `list_notes(project_id)`, `get_note(note_id)`, `create_note(title, content, project_id)`, `update_note(note_id, ...)`, `delete_note(note_id)`
|
|
||||||
- content is Markdown; `get_note` should be called before update to preserve existing content
|
|
||||||
- [x] `app/agents/__init__.py`: imports all four agent modules to trigger `@registry.register` decorators
|
|
||||||
- [x] Unit tests per agent with mocked LLM (registration, names, tool counts, handle(), direct tool invocation)
|
|
||||||
- **Outcome:** Four domain-specific agents matching the UI data model (Tasks, Timelines, Projects, Notes), all registered and tested.
|
|
||||||
|
|
||||||
### Step 7 — Storage Layer ✅
|
|
||||||
- [x] `app/storage/blob_store.py`:
|
|
||||||
- `BlobStore`: `async upload`, `async download`, `async delete` (idempotent), `async list_keys`
|
|
||||||
- Keys: `{user_id}/{table}/{record_id}` — backend never inspects blob content
|
|
||||||
- boto3 S3 with SSE-S3 at-rest encryption; client checksum stored in S3 object metadata
|
|
||||||
- [x] `app/storage/vector_store.py`:
|
|
||||||
- `VectorStore`: `async upsert`, `async search`, `async delete`
|
|
||||||
- Pinecone (default, `namespace=user_id`) or Qdrant (`user_id` payload filter) — runtime-configurable
|
|
||||||
- 32-dim SHA-256-derived float vector; blob stored as base64 in metadata/payload
|
|
||||||
- ANN on encrypted data: known accuracy trade-off, documented
|
|
||||||
- [x] `app/storage/encryption.py`:
|
|
||||||
- `verify_checksum(blob, checksum) -> bool` — SHA-256 + `hmac.compare_digest` (constant-time)
|
|
||||||
- `reject_if_tampered(blob, checksum)` — raises `HTTP 400` on mismatch
|
|
||||||
- Backend NEVER holds decryption keys
|
|
||||||
- [x] `app/schemas.py`: added `StorageRecord*`, `VectorItem`, `VectorUpsertRequest`, `VectorSearch*`, `Plugin*` schemas
|
|
||||||
- [x] `app/config/settings.py`: added `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
|
||||||
- [x] `requirements.txt`: added `moto[s3]`, `pinecone`, `qdrant-client`
|
|
||||||
- [x] 37 unit tests covering encryption, BlobStore (moto), VectorStore Pinecone, VectorStore Qdrant
|
|
||||||
- **Outcome:** Cloud storage layer that handles E2E encrypted blobs without ever accessing plaintext.
|
|
||||||
|
|
||||||
### Step 8 — API Routes ✅
|
|
||||||
|
|
||||||
#### 8a — Chat endpoint
|
|
||||||
- [x] `app/api/routes/chat.py`:
|
|
||||||
- `POST /api/v1/chat`:
|
|
||||||
- Request: `ChatRequest`
|
|
||||||
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
|
|
||||||
- Response: `ChatResponse` or `ExecutionPlan`
|
|
||||||
- `WebSocket /api/v1/chat/stream`:
|
|
||||||
- Client sends `ChatRequest` as first JSON frame
|
|
||||||
- Server yields token strings via `orchestrate_stream()`
|
|
||||||
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
|
|
||||||
- Heartbeat ping every 30s to keep connection alive
|
|
||||||
|
|
||||||
#### 8b — Plans endpoint
|
|
||||||
- [x] `app/api/routes/plans.py`:
|
|
||||||
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
|
|
||||||
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
|
|
||||||
|
|
||||||
#### 8c — Storage endpoint (cloud records)
|
|
||||||
- [x] `app/api/routes/storage.py`:
|
|
||||||
- `POST /api/v1/storage/records`: Create encrypted record
|
|
||||||
- Request: `StorageRecordCreate`
|
|
||||||
- Verifies checksum, stores blob in S3, inserts metadata row in PostgreSQL
|
|
||||||
- Response: `{id: str, created_at: int}`
|
|
||||||
- `GET /api/v1/storage/records`: List record metadata (no blobs)
|
|
||||||
- Query params: `table: str`, `page: int`, `limit: int`
|
|
||||||
- Response: `list[{id, table, checksum, created_at, updated_at}]`
|
|
||||||
- `GET /api/v1/storage/records/{id}`: Download encrypted blob
|
|
||||||
- Response: blob bytes + `X-Checksum` header
|
|
||||||
- `PUT /api/v1/storage/records/{id}`: Update encrypted blob
|
|
||||||
- Request: `StorageRecordUpdate`
|
|
||||||
- `DELETE /api/v1/storage/records/{id}`: Delete record + S3 blob
|
|
||||||
- All routes enforce tier cloud_storage_gb quota via `TierManager.check_quota(user_id)`
|
|
||||||
|
|
||||||
#### 8d — Vectors endpoint (cloud vector store)
|
|
||||||
- [x] `app/api/routes/vectors.py`:
|
|
||||||
- `POST /api/v1/storage/vectors/upsert`:
|
|
||||||
- Request: `VectorUpsertRequest`
|
|
||||||
- Verifies checksums, delegates to `VectorStore.upsert()`
|
|
||||||
- Response: `{upserted: int}`
|
|
||||||
- `POST /api/v1/storage/vectors/search`:
|
|
||||||
- Request: `VectorSearchRequest`
|
|
||||||
- Delegates to `VectorStore.search()`
|
|
||||||
- Response: `VectorSearchResponse`
|
|
||||||
- `DELETE /api/v1/storage/vectors`:
|
|
||||||
- Request: `{ids: list[str]}`
|
|
||||||
|
|
||||||
#### 8e — Backup endpoint
|
|
||||||
- [x] `app/api/routes/backup.py`:
|
|
||||||
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
|
|
||||||
- Free: 0 (no backup)
|
|
||||||
- Pro: 5 GB
|
|
||||||
- Power: 25 GB
|
|
||||||
- Team: unlimited
|
|
||||||
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
|
|
||||||
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
|
|
||||||
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
|
|
||||||
|
|
||||||
#### 8f — Plugins endpoint
|
|
||||||
- [x] `app/api/routes/plugins.py`:
|
|
||||||
- `GET /api/v1/plugins`:
|
|
||||||
- Query params: `category: str | None`, `q: str | None`, `page: int`, `sort: Literal['rating', 'installs', 'newest']`
|
|
||||||
- Response: `PluginListResponse`
|
|
||||||
- Available from Power tier and above
|
|
||||||
- `GET /api/v1/plugins/{id}`:
|
|
||||||
- Response: `PluginManifest` + ratings + install count
|
|
||||||
- `POST /api/v1/plugins/{id}/install`:
|
|
||||||
- Request: `PluginInstallRequest`
|
|
||||||
- Records installation for the user (billing tracking, analytics)
|
|
||||||
- If plugin is paid: triggers Stripe Connect charge + revenue split (70% developer, 30% platform)
|
|
||||||
- Response: `{ok: true, download_url: str}` — signed S3 URL for plugin package
|
|
||||||
- `DELETE /api/v1/plugins/{id}/install`:
|
|
||||||
- Unregisters installation
|
|
||||||
|
|
||||||
#### 8g — Auth endpoint
|
|
||||||
- [x] `app/api/routes/auth.py`:
|
|
||||||
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
|
|
||||||
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
|
|
||||||
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
|
|
||||||
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
|
|
||||||
|
|
||||||
#### 8h — Billing endpoint
|
|
||||||
- [x] `app/api/routes/billing.py`:
|
|
||||||
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
|
|
||||||
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
|
|
||||||
- `GET /api/v1/billing/subscription`: Returns current subscription info
|
|
||||||
- `DELETE /api/v1/billing/subscription`: Cancels subscription
|
|
||||||
|
|
||||||
- **Outcome:** Complete REST + WebSocket API covering orchestration, storage, vectors, backup, marketplace.
|
|
||||||
|
|
||||||
### Step 9 — Middleware
|
|
||||||
|
|
||||||
#### 9a — Auth middleware
|
|
||||||
- [x] `app/api/middleware/auth.py`:
|
|
||||||
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
|
||||||
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
|
||||||
- Raises `401` on invalid/expired token
|
|
||||||
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
|
||||||
|
|
||||||
#### 9b — Rate limiter
|
|
||||||
- [x] `app/api/middleware/rate_limit.py`:
|
|
||||||
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
|
||||||
- Tier-based limits:
|
|
||||||
- Free: 20 req/min
|
|
||||||
- Pro: 60 req/min
|
|
||||||
- Power: 120 req/min
|
|
||||||
- Team: 200 req/seat/min
|
|
||||||
- Custom 429 response with `Retry-After` header
|
|
||||||
|
|
||||||
#### 9c — Sanitizer
|
|
||||||
- [x] `app/api/middleware/sanitizer.py`:
|
|
||||||
- Response middleware that scans response bodies
|
|
||||||
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
|
||||||
- Pattern-based detection + exact match against known prompt fingerprints
|
|
||||||
- Logs sanitization events for monitoring
|
|
||||||
|
|
||||||
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
|
||||||
|
|
||||||
### Step 10 — Plugin Marketplace ✅
|
|
||||||
- [x] `app/marketplace/plugin_registry.py`:
|
|
||||||
- `PluginRegistry`:
|
|
||||||
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
|
|
||||||
- `async get_plugin(plugin_id) -> PluginManifest | None`
|
|
||||||
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
|
|
||||||
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
|
|
||||||
- `async reject_plugin(plugin_id, reason: str) -> None`
|
|
||||||
- [x] `app/marketplace/plugin_review.py`:
|
|
||||||
- `ReviewQueue`:
|
|
||||||
- `async get_pending() -> list[dict]`
|
|
||||||
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
|
|
||||||
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
|
|
||||||
- [x] `app/marketplace/revenue_share.py`:
|
|
||||||
- `RevenueShare`:
|
|
||||||
- `async record_install(plugin_id, user_id, amount_cents) -> None`
|
|
||||||
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
|
|
||||||
- `async get_earnings(developer_id, period) -> dict`
|
|
||||||
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
|
|
||||||
|
|
||||||
### Step 11 — Billing & Tier management ✅
|
|
||||||
- [x] `app/billing/stripe_service.py`:
|
|
||||||
- `create_checkout_session(user_id, tier) -> str`
|
|
||||||
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
|
||||||
- `get_subscription(user_id) -> dict | None`
|
|
||||||
- `cancel_subscription(user_id) -> None`
|
|
||||||
- [x] `app/billing/tier_manager.py`:
|
|
||||||
- `TierManager`:
|
|
||||||
- Feature matrix:
|
|
||||||
```python
|
|
||||||
FEATURES = {
|
|
||||||
'free': {
|
|
||||||
'agents': 3,
|
|
||||||
'batch_active': 2,
|
|
||||||
'cloud_storage_gb': 0,
|
|
||||||
'backup_gb': 0,
|
|
||||||
'providers': 1,
|
|
||||||
'batch_builder': False,
|
|
||||||
'plugin_marketplace': False,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'pro': {
|
|
||||||
'agents': -1, # unlimited
|
|
||||||
'batch_active': 10,
|
|
||||||
'cloud_storage_gb': 5,
|
|
||||||
'backup_gb': 5,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': False,
|
|
||||||
'plugin_marketplace': False,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'power': {
|
|
||||||
'agents': -1,
|
|
||||||
'batch_active': -1, # unlimited
|
|
||||||
'cloud_storage_gb': 25,
|
|
||||||
'backup_gb': 25,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': True,
|
|
||||||
'plugin_marketplace': True,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'team': {
|
|
||||||
'agents': -1,
|
|
||||||
'batch_active': -1,
|
|
||||||
'cloud_storage_gb': -1,
|
|
||||||
'backup_gb': -1,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': True,
|
|
||||||
'plugin_marketplace': True,
|
|
||||||
'sso': True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
```
|
|
||||||
- `get_tier(user_id) -> BillingTier`
|
|
||||||
- `check_feature(user_id, feature) -> bool`
|
|
||||||
- `get_rate_limit(tier) -> int`
|
|
||||||
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
|
|
||||||
- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons
|
|
||||||
- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService`
|
|
||||||
- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota`
|
|
||||||
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
|
||||||
|
|
||||||
### Step 12 — Database (auth/billing/marketplace only)
|
|
||||||
- [x] PostgreSQL schema via Alembic:
|
|
||||||
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
|
|
||||||
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
|
|
||||||
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
|
|
||||||
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
|
|
||||||
- `storage_records`: `id UUID PK`, `user_id FK`, `table_name VARCHAR`, `s3_key`, `checksum`, `size_bytes`, `created_at`, `updated_at` — metadata only, no plaintext
|
|
||||||
- `plugins`: `id UUID PK`, `name`, `description`, `version`, `author_id FK`, `category`, `status` (pending_review/approved/rejected), `price_cents`, `s3_package_key`, `install_count`, `avg_rating`, `created_at`
|
|
||||||
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
|
|
||||||
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
|
|
||||||
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
|
|
||||||
- [x] Initial Alembic migration
|
|
||||||
- [x] SQLAlchemy models in `app/models.py`
|
|
||||||
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
|
|
||||||
|
|
||||||
### Step 13 — Testing & deployment ✅
|
|
||||||
- [x] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone
|
|
||||||
- [x] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
|
|
||||||
- [x] `tests/test_agents.py`: each agent with mocked tools
|
|
||||||
- [x] `tests/test_auth.py`: register → login → access protected → refresh → expired token
|
|
||||||
- [x] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
|
|
||||||
- [x] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement
|
|
||||||
- [x] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked)
|
|
||||||
- [x] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
|
|
||||||
- [x] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
|
|
||||||
- **Outcome:** Fully tested, deployable backend.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## API Contract Summary
|
|
||||||
|
|
||||||
| Method | Endpoint | Auth | Request | Response |
|
|
||||||
|--------|----------|------|---------|----------|
|
|
||||||
| POST | `/api/v1/auth/register` | No | `{email, password}` | `AuthTokens` |
|
|
||||||
| POST | `/api/v1/auth/login` | No | `{email, password}` | `AuthTokens` |
|
|
||||||
| POST | `/api/v1/auth/refresh` | No | `{refresh_token}` | `AuthTokens` |
|
|
||||||
| GET | `/api/v1/auth/me` | JWT | — | `UserProfile` |
|
|
||||||
| POST | `/api/v1/chat` | JWT | `ChatRequest` | `ChatResponse \| ExecutionPlan` |
|
|
||||||
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
|
|
||||||
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
|
|
||||||
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
|
|
||||||
| POST | `/api/v1/storage/records` | JWT | `StorageRecordCreate` | `{id, created_at}` |
|
|
||||||
| GET | `/api/v1/storage/records` | JWT | `?table&page&limit` | `RecordMeta[]` |
|
|
||||||
| GET | `/api/v1/storage/records/:id` | JWT | — | Binary blob |
|
|
||||||
| PUT | `/api/v1/storage/records/:id` | JWT | `StorageRecordUpdate` | `{ok: true}` |
|
|
||||||
| DELETE | `/api/v1/storage/records/:id` | JWT | — | `{ok: true}` |
|
|
||||||
| POST | `/api/v1/storage/vectors/upsert` | JWT | `VectorUpsertRequest` | `{upserted: int}` |
|
|
||||||
| POST | `/api/v1/storage/vectors/search` | JWT | `VectorSearchRequest` | `VectorSearchResponse` |
|
|
||||||
| DELETE | `/api/v1/storage/vectors` | JWT | `{ids: list[str]}` | `{ok: true}` |
|
|
||||||
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/backup` | JWT | — | Binary blob |
|
|
||||||
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
|
|
||||||
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/plugins` | JWT | `?category&q&page&sort` | `PluginListResponse` |
|
|
||||||
| GET | `/api/v1/plugins/:id` | JWT | — | `PluginManifest` + stats |
|
|
||||||
| POST | `/api/v1/plugins/:id/install` | JWT | `PluginInstallRequest` | `{ok, download_url}` |
|
|
||||||
| DELETE | `/api/v1/plugins/:id/install` | JWT | — | `{ok: true}` |
|
|
||||||
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
|
|
||||||
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
|
|
||||||
| DELETE | `/api/v1/billing/subscription` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/health` | No | — | `{status, version}` |
|
|
||||||
| GET | `/api/v1/agents/catalog` | JWT | — | `AgentCatalogItem[]` |
|
|
||||||
| GET | `/api/v1/agents/local` | JWT | — | `LocalAgentConfigResponse[]` |
|
|
||||||
| POST | `/api/v1/agents/local` | JWT | `LocalAgentConfigCreate` | `LocalAgentConfigResponse` |
|
|
||||||
| PUT | `/api/v1/agents/local/{id}` | JWT | `LocalAgentConfigUpdate` | `LocalAgentConfigResponse` |
|
|
||||||
| DELETE | `/api/v1/agents/local/{id}` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/agents/cloud` | JWT | — | `CloudAgentConfigResponse[]` |
|
|
||||||
| POST | `/api/v1/agents/cloud` | JWT | `CloudAgentConfigCreate` | `CloudAgentConfigResponse` |
|
|
||||||
| PUT | `/api/v1/agents/cloud/{id}` | JWT | `CloudAgentConfigUpdate` | `CloudAgentConfigResponse` |
|
|
||||||
| DELETE | `/api/v1/agents/cloud/{id}` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/agents/runs` | JWT | `?agent_id&page&limit` | `AgentRunLogResponse[]` |
|
|
||||||
| POST | `/api/v1/agents/{id}/run` | JWT | — | `{ok: true, run_id}` |
|
|
||||||
| POST | `/api/v1/agents/journey/start` | JWT | `{agent_type, data_types}` | `{session_id, message, done}` |
|
|
||||||
| POST | `/api/v1/agents/journey/message` | JWT | `{session_id, message}` | `{session_id, message, done, prompt_template?}` |
|
|
||||||
| GET | `/api/v1/oauth/{provider}/authorize` | JWT | — | `{authorization_url}` |
|
|
||||||
| GET | `/api/v1/oauth/{provider}/callback` | — | OAuth code | `{encrypted_token}` |
|
|
||||||
| WS | `/api/v1/ws/device` | JWT | `device_hello` (first frame) | Agent trigger + tool_call frames |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stack
|
|
||||||
|
|
||||||
| Layer | Technology |
|
|
||||||
|-------|-----------|
|
|
||||||
| Framework | FastAPI + Uvicorn |
|
|
||||||
| LLM | LangChain + langchain-openai |
|
|
||||||
| Auth | PyJWT + bcrypt + OAuth2 |
|
|
||||||
| Billing | stripe-python + Stripe Connect |
|
|
||||||
| Blob storage | boto3 (S3) |
|
|
||||||
| Vector store | Pinecone or Qdrant (configurable) |
|
|
||||||
| Database | PostgreSQL + SQLAlchemy + Alembic |
|
|
||||||
| Rate limiting | slowapi |
|
|
||||||
| Cloud integrations | google-api-python-client, msgraph-sdk, msal |
|
|
||||||
| Agent scheduling | APScheduler |
|
|
||||||
| Testing | pytest + pytest-asyncio + httpx + moto (S3 mock) |
|
|
||||||
| Deployment | Docker → fly.io / Railway / AWS ECS |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 3 — New Files
|
|
||||||
|
|
||||||
| File | Purpose |
|
|
||||||
|---|---|
|
|
||||||
| `app/models.py` | Add `LocalAgentConfig`, `CloudAgentConfig`, `AgentRunLog` models |
|
|
||||||
| `app/schemas.py` | Add agent config schemas + WS agent frame types |
|
|
||||||
| `app/api/routes/agents.py` | Agent CRUD endpoints (catalog, local, cloud, runs, manual trigger) |
|
|
||||||
| `app/api/routes/agent_setup.py` | Chatbot Journey endpoints (start + message) |
|
|
||||||
| `app/api/routes/device_ws.py` | Persistent device WS endpoint (`/api/v1/ws/device`) |
|
|
||||||
| `app/api/routes/oauth.py` | OAuth authorize/callback for Gmail, Teams, Outlook |
|
|
||||||
| `app/core/agent_runner.py` | Agent run orchestration — local (WS file request) + cloud (API fetch) |
|
|
||||||
| `app/core/device_manager.py` | `DeviceConnectionManager` — tracks active Electron WS connections |
|
|
||||||
| `app/core/agent_scheduler.py` | Periodic scheduler for agent cron triggers |
|
|
||||||
| `app/integrations/gmail.py` | Gmail API client (fetch messages with filters) |
|
|
||||||
| `app/integrations/ms_graph.py` | MS Graph client for Outlook emails + Teams messages |
|
|
||||||
| `app/integrations/__init__.py` | Provider factory |
|
|
||||||
|
|
||||||
> **Full Phase 3 step-by-step plan:** See `AI_REFACTOR_PLAN.md` Phase 3 section.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Development Rules
|
|
||||||
|
|
||||||
1. **NEVER persist user data in plaintext.** The DB stores only auth, billing, storage metadata, and marketplace data. User context arrives in requests and is discarded. Cloud blobs are E2E encrypted client-side — backend only stores opaque bytes.
|
|
||||||
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. In plan mode, `prompt_template` fields are reference IDs only.
|
|
||||||
3. **NEVER decrypt user blobs.** `app/storage/encryption.py` only verifies checksums. No decryption key ever reaches the backend.
|
|
||||||
4. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
|
|
||||||
5. **Type hints everywhere.** All functions have full type annotations.
|
|
||||||
6. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
|
|
||||||
7. **Structured logging.** JSON logs with request ID correlation.
|
|
||||||
8. **Tier gates are enforced server-side.** Never trust client-reported tier. Always fetch from DB via `TierManager.get_tier(user_id)`.
|
|
||||||
9. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.
|
|
||||||
793
README.md
793
README.md
@@ -1,793 +0,0 @@
|
|||||||
# Adiuva Cloud API
|
|
||||||
|
|
||||||
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
|
|
||||||
|
|
||||||
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Table of Contents
|
|
||||||
|
|
||||||
- [Overview](#overview)
|
|
||||||
- [Architecture](#architecture)
|
|
||||||
- [Key Features](#key-features)
|
|
||||||
- [Tech Stack](#tech-stack)
|
|
||||||
- [Getting Started](#getting-started)
|
|
||||||
- [Docker Deployment](#docker-deployment)
|
|
||||||
- [Environment Variables](#environment-variables)
|
|
||||||
- [API Reference](#api-reference)
|
|
||||||
- [Data Model](#data-model)
|
|
||||||
- [AI Agent System](#ai-agent-system)
|
|
||||||
- [Orchestration & Execution Plans](#orchestration--execution-plans)
|
|
||||||
- [Middleware](#middleware)
|
|
||||||
- [Storage Layer](#storage-layer)
|
|
||||||
- [Billing & Tiers](#billing--tiers)
|
|
||||||
- [Plugin Marketplace](#plugin-marketplace)
|
|
||||||
- [Testing](#testing)
|
|
||||||
- [Project Structure](#project-structure)
|
|
||||||
- [License](#license)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
|
|
||||||
|
|
||||||
### Design Principles
|
|
||||||
|
|
||||||
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
|
|
||||||
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
|
|
||||||
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
|
|
||||||
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
|
|
||||||
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────┐ ┌────────────────────────────────────────────────────────┐
|
|
||||||
│ Electron │ │ FastAPI (Uvicorn / Gunicorn) │
|
|
||||||
│ Desktop App │────▶│ │
|
|
||||||
│ (Client) │◀────│ Middleware: RateLimit → Sanitizer → CORS → Router │
|
|
||||||
└──────────────┘ │ │
|
|
||||||
│ ┌──────────────────┐ ┌────────────────────────────┐ │
|
|
||||||
│ │ Auth Routes │ │ Chat Routes │ │
|
|
||||||
│ │ Billing Routes │ │ ↓ │ │
|
|
||||||
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
|
|
||||||
│ │ Backup Routes │ │ ↓ classify intent │ │
|
|
||||||
│ │ Plugin Routes │ │ Agent Registry │ │
|
|
||||||
│ │ Vector Routes │ │ ↓ │ │
|
|
||||||
│ │ Plans Routes │ │ TaskAgent | ProjectAgent │ │
|
|
||||||
│ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
|
|
||||||
│ │ (GPT-4o + LangChain) │ │
|
|
||||||
│ └────────────────────────────┘ │
|
|
||||||
└────────────────────────────────────────────────────────┘
|
|
||||||
│ │ │
|
|
||||||
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
|
|
||||||
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
|
|
||||||
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
|
|
||||||
│ Billing, │ │ backups) │ │ (Vectors) │
|
|
||||||
│ Metadata) │ └───────────────┘ └────────────────┘
|
|
||||||
└────────────┘
|
|
||||||
│
|
|
||||||
┌────────▼───┐
|
|
||||||
│ Stripe │
|
|
||||||
│ (Billing, │
|
|
||||||
│ Connect) │
|
|
||||||
└────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Key Features
|
|
||||||
|
|
||||||
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
|
|
||||||
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
|
|
||||||
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
|
|
||||||
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
|
|
||||||
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
|
|
||||||
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
|
|
||||||
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
|
|
||||||
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
|
|
||||||
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
|
|
||||||
10. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
|
|
||||||
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
|
|
||||||
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
|
|
||||||
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
|
|
||||||
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
|
|
||||||
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tech Stack
|
|
||||||
|
|
||||||
| Package | Version | Purpose |
|
|
||||||
|---|---|---|
|
|
||||||
| `fastapi` | ≥ 0.115.0 | Web framework |
|
|
||||||
| `uvicorn[standard]` | ≥ 0.34.0 | ASGI development server |
|
|
||||||
| `gunicorn` | ≥ 22.0.0 | Production process manager |
|
|
||||||
| `langchain` | ≥ 0.3.0 | LLM orchestration framework |
|
|
||||||
| `langchain-openai` | ≥ 0.3.0 | OpenAI LLM provider integration |
|
|
||||||
| `litellm` | ≥ 1.50.0 | Universal LLM gateway (100+ providers) |
|
|
||||||
| `pydantic` | ≥ 2.10.0 | Data validation and serialization |
|
|
||||||
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
|
|
||||||
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
|
|
||||||
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
|
|
||||||
| `boto3` | ≥ 1.35.0 | AWS S3 client |
|
|
||||||
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
|
|
||||||
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
|
|
||||||
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
|
|
||||||
| `alembic` | ≥ 1.14.0 | Database migration management |
|
|
||||||
| `bcrypt` | ≥ 4.2.0 | Password hashing |
|
|
||||||
| `python-dotenv` | ≥ 1.0.0 | `.env` file loading |
|
|
||||||
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
|
|
||||||
| `websockets` | ≥ 14.0 | WebSocket protocol support |
|
|
||||||
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
|
|
||||||
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
|
|
||||||
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
|
|
||||||
| `pytest` | ≥ 8.0.0 | Test framework |
|
|
||||||
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
|
|
||||||
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
|
|
||||||
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
|
|
||||||
| `ruff` | ≥ 0.8.0 | Linter and formatter |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Getting Started
|
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
|
|
||||||
- Python 3.12+
|
|
||||||
- PostgreSQL 16+
|
|
||||||
- An OpenAI API key (for LLM features)
|
|
||||||
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
|
|
||||||
- AWS credentials (optional — needed for S3 storage in production)
|
|
||||||
|
|
||||||
### Installation
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Clone the repository
|
|
||||||
git clone <repo-url> && cd adiuva-api
|
|
||||||
|
|
||||||
# Create a virtual environment
|
|
||||||
python -m venv .venv && source .venv/bin/activate
|
|
||||||
|
|
||||||
# Install dependencies
|
|
||||||
pip install -r requirements.txt
|
|
||||||
|
|
||||||
# Configure environment
|
|
||||||
cp .env.example .env
|
|
||||||
# Edit .env with your DATABASE_URL, OPENAI_API_KEY, etc.
|
|
||||||
```
|
|
||||||
|
|
||||||
### Database Setup
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Start PostgreSQL (or use the Docker Compose database)
|
|
||||||
docker compose up db -d
|
|
||||||
|
|
||||||
# Run migrations
|
|
||||||
alembic upgrade head
|
|
||||||
```
|
|
||||||
|
|
||||||
### Run the Development Server
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
|
||||||
```
|
|
||||||
|
|
||||||
Interactive API docs are available at [http://localhost:8000/docs](http://localhost:8000/docs) in development mode (`ENV=dev`). The `/docs` endpoint is disabled in production.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Docker Deployment
|
|
||||||
|
|
||||||
### Quick Start
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose up --build
|
|
||||||
```
|
|
||||||
|
|
||||||
This starts two services:
|
|
||||||
|
|
||||||
- **app** — FastAPI server on port `8000`
|
|
||||||
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
|
|
||||||
|
|
||||||
The compose file also includes optional services for fully local deployments:
|
|
||||||
|
|
||||||
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
|
|
||||||
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
|
|
||||||
|
|
||||||
### Dockerfile Details
|
|
||||||
|
|
||||||
The Dockerfile uses a multi-stage build:
|
|
||||||
|
|
||||||
1. **Builder stage** — Installs Python dependencies into a virtual environment.
|
|
||||||
2. **Runtime stage** — Copies only the venv, app source, and Alembic migrations. Runs as a non-root user (`appuser`).
|
|
||||||
3. **Production server** — Gunicorn with 4 Uvicorn workers, 120-second timeout, listening on port 8000.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Production command (run by the container)
|
|
||||||
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0.0.0:8000
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Homelab / Self-Hosted Deployment
|
|
||||||
|
|
||||||
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
|
|
||||||
|
|
||||||
### 1. Start all services
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
|
|
||||||
|
|
||||||
### 2. Create the MinIO bucket
|
|
||||||
|
|
||||||
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
|
|
||||||
docker compose exec minio mc mb local/adiuva
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Configure your `.env`
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Database (uses the compose PostgreSQL)
|
|
||||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
|
|
||||||
# S3 → MinIO
|
|
||||||
S3_BUCKET=adiuva
|
|
||||||
S3_REGION=us-east-1
|
|
||||||
S3_ENDPOINT_URL=http://minio:9000
|
|
||||||
AWS_ACCESS_KEY_ID=minioadmin
|
|
||||||
AWS_SECRET_ACCESS_KEY=minioadmin
|
|
||||||
|
|
||||||
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
|
|
||||||
QDRANT_URL=http://qdrant:6333
|
|
||||||
QDRANT_API_KEY=
|
|
||||||
PINECONE_API_KEY=
|
|
||||||
|
|
||||||
# Billing — leave empty to stub (no Stripe needed)
|
|
||||||
STRIPE_SECRET_KEY=
|
|
||||||
STRIPE_WEBHOOK_SECRET=
|
|
||||||
|
|
||||||
# LLM — the only external service
|
|
||||||
OPENAI_API_KEY=sk-...
|
|
||||||
LLM_MODEL=gpt-4o
|
|
||||||
LLM_ROUTER_MODEL=gpt-4o-mini
|
|
||||||
|
|
||||||
# Auth
|
|
||||||
JWT_SECRET=your-secret-here
|
|
||||||
ENV=dev
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. Run migrations
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose exec app alembic upgrade head
|
|
||||||
```
|
|
||||||
|
|
||||||
### What runs where
|
|
||||||
|
|
||||||
| Service | Runs on | Port | Notes |
|
|
||||||
|---|---|---|---|
|
|
||||||
| FastAPI app | Docker | 8000 | API server |
|
|
||||||
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
|
|
||||||
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
|
|
||||||
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
|
|
||||||
| Stripe | — | — | Stubbed when keys are empty |
|
|
||||||
| OpenAI / LLM | Cloud | — | Only external dependency |
|
|
||||||
|
|
||||||
> **Want fully offline AI too?** Set `LLM_MODEL=ollama/llama3` and `LLM_ROUTER_MODEL=ollama/llama3`, then add an Ollama container or point at a local Ollama instance. See the [LLM provider switching](#switching-llm-providers) section.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Environment Variables
|
|
||||||
|
|
||||||
All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/config/settings.py`
|
|
||||||
|
|
||||||
| Variable | Type | Default | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
|
|
||||||
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
|
|
||||||
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
|
|
||||||
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
|
|
||||||
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
|
|
||||||
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
|
|
||||||
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
|
|
||||||
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
|
|
||||||
| `S3_REGION` | `str` | `us-east-1` | AWS region |
|
|
||||||
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
|
|
||||||
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
|
|
||||||
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
|
|
||||||
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
|
|
||||||
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
|
|
||||||
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
|
|
||||||
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
|
|
||||||
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
|
|
||||||
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
|
|
||||||
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
|
|
||||||
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
|
|
||||||
| `ENV` | `Literal` | `dev` | `dev` or `prod` — controls `/docs` visibility and SQL echo |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## API Reference
|
|
||||||
|
|
||||||
All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebSocket + 1 health check).
|
|
||||||
|
|
||||||
### Health
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `GET` | `/api/v1/health` | No | Returns `{"status": "ok", "version": "0.1.0"}` |
|
|
||||||
|
|
||||||
### Auth
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/auth/register` | No | Create account with bcrypt-hashed password, returns `AuthTokens` |
|
|
||||||
| `POST` | `/api/v1/auth/login` | No | Validate credentials, returns `AuthTokens` |
|
|
||||||
| `POST` | `/api/v1/auth/refresh` | No | Rotate refresh token, returns new `AuthTokens` |
|
|
||||||
| `GET` | `/api/v1/auth/me` | JWT | Returns `UserProfile` for the authenticated user |
|
|
||||||
|
|
||||||
### Chat
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
|
|
||||||
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
|
|
||||||
|
|
||||||
### Plans
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
|
|
||||||
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
|
|
||||||
|
|
||||||
### Storage (Cloud Records)
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
|
|
||||||
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
|
|
||||||
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
|
|
||||||
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
|
|
||||||
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
|
|
||||||
|
|
||||||
### Vectors (Cloud Vector Store)
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
|
|
||||||
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
|
|
||||||
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
|
|
||||||
|
|
||||||
### Backup
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
|
|
||||||
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
|
|
||||||
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
|
|
||||||
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
|
|
||||||
|
|
||||||
### Plugins (Marketplace)
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
|
|
||||||
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
|
|
||||||
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
|
|
||||||
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
|
|
||||||
|
|
||||||
### Billing
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/billing/checkout` | JWT | Create a Stripe checkout session, returns `{"checkout_url": "..."}` |
|
|
||||||
| `POST` | `/api/v1/billing/webhook` | Stripe signature | Handle Stripe events: `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` |
|
|
||||||
| `GET` | `/api/v1/billing/subscription` | JWT | Get current subscription information |
|
|
||||||
| `DELETE` | `/api/v1/billing/subscription` | JWT | Cancel subscription and revert to free tier |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Data Model
|
|
||||||
|
|
||||||
9 tables managed by Alembic migrations. Source: `app/models.py`
|
|
||||||
|
|
||||||
### Tables
|
|
||||||
|
|
||||||
| Table | Primary Key | Key Columns | Purpose |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
|
|
||||||
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
|
|
||||||
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
|
|
||||||
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
|
|
||||||
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
|
|
||||||
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
|
|
||||||
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
|
|
||||||
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
|
|
||||||
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
|
|
||||||
|
|
||||||
### Enum Types
|
|
||||||
|
|
||||||
| Enum | Values |
|
|
||||||
|---|---|
|
|
||||||
| `billing_tier` | `free`, `pro`, `power`, `team` |
|
|
||||||
| `plugin_status` | `pending_review`, `approved`, `rejected` |
|
|
||||||
| `review_decision` | `approved`, `rejected` |
|
|
||||||
|
|
||||||
### Migrations
|
|
||||||
|
|
||||||
| Version | Description |
|
|
||||||
|---|---|
|
|
||||||
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
|
|
||||||
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## AI Agent System
|
|
||||||
|
|
||||||
The agent system uses a registry pattern with LangChain tool-calling agents powered by GPT-4o. Source: `app/agents/`, `app/core/agent_registry.py`
|
|
||||||
|
|
||||||
### Architecture
|
|
||||||
|
|
||||||
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
|
|
||||||
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
|
|
||||||
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
|
|
||||||
|
|
||||||
### Registered Agents
|
|
||||||
|
|
||||||
| Agent | Registry Name | Tools | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
|
|
||||||
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
|
|
||||||
| **TimelineAgent** | `timeline_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_timelines`, `create_timeline`, `update_timeline`, `delete_timeline` |
|
|
||||||
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
|
|
||||||
|
|
||||||
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
|
|
||||||
|
|
||||||
### Switching LLM Providers
|
|
||||||
|
|
||||||
The backend uses **LiteLLM** as a universal LLM gateway. All agents and the orchestrator instantiate models through a centralized factory in `app/core/llm.py`. To switch providers, change environment variables — no code changes required:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# OpenAI (default)
|
|
||||||
LLM_MODEL=gpt-4o
|
|
||||||
LLM_ROUTER_MODEL=gpt-4o-mini
|
|
||||||
|
|
||||||
# Anthropic
|
|
||||||
LLM_MODEL=anthropic/claude-3.5-sonnet
|
|
||||||
LLM_ROUTER_MODEL=anthropic/claude-3-haiku
|
|
||||||
|
|
||||||
# Google Gemini
|
|
||||||
LLM_MODEL=gemini/gemini-pro
|
|
||||||
LLM_ROUTER_MODEL=gemini/gemini-flash
|
|
||||||
|
|
||||||
# Local Ollama
|
|
||||||
LLM_MODEL=ollama/llama3
|
|
||||||
LLM_ROUTER_MODEL=ollama/llama3
|
|
||||||
|
|
||||||
# AWS Bedrock
|
|
||||||
LLM_MODEL=bedrock/anthropic.claude-v2
|
|
||||||
LLM_ROUTER_MODEL=bedrock/anthropic.claude-instant-v1
|
|
||||||
```
|
|
||||||
|
|
||||||
See the [LiteLLM provider docs](https://docs.litellm.ai/docs/providers) for the full list of 100+ supported providers and model naming conventions.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Orchestration & Execution Plans
|
|
||||||
|
|
||||||
Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
|
|
||||||
|
|
||||||
### Orchestrator
|
|
||||||
|
|
||||||
1. **`classify_intent(message, context, registry)`** — Uses the router model (`LLM_ROUTER_MODEL`, default: GPT-4o-mini) to determine which agent should handle a message. Falls back to `task_agent` when classification is ambiguous.
|
|
||||||
2. **`route_single(agent_name, message, context)`** — Routes to a single agent and returns a `ChatResponse`.
|
|
||||||
3. **`route_pipeline(agent_names, message, context)`** — Executes agents sequentially; each receives `previous_results` from earlier agents. A final LLM synthesis step merges all results.
|
|
||||||
4. **`orchestrate(request)`** — Main entry point. In `direct` mode, returns a `ChatResponse`. In `plan` mode, returns an `ExecutionPlan`.
|
|
||||||
5. **`orchestrate_stream(request)`** — Streaming variant that yields 50-character text chunks with a final JSON frame.
|
|
||||||
|
|
||||||
### Execution Plans
|
|
||||||
|
|
||||||
- **`PromptTemplateRegistry`** — Maps template IDs to server-side prompt text. Clients only ever see opaque IDs, never raw prompts.
|
|
||||||
- **`ExecutionPlanBuilder`** — Fluent builder API: `add_step()`, `add_llm_step(template_id, vars)`, `add_data_step(action, data_from_step)`. Validates step references on `build()`.
|
|
||||||
- **`PlanCache`** — LRU cache (maxsize 1000) for storing plans as reusable playbooks.
|
|
||||||
|
|
||||||
### Built-in Templates (6)
|
|
||||||
|
|
||||||
`tpl_task_agent_default`, `tpl_timeline_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
|
|
||||||
|
|
||||||
### Built-in Playbooks (2)
|
|
||||||
|
|
||||||
| Playbook | Description |
|
|
||||||
|---|---|
|
|
||||||
| `create_tasks_from_project` | LLM extracts actionable tasks from project context, then creates task records |
|
|
||||||
| `generate_weekly_note` | LLM generates a weekly summary, then creates a note record |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Middleware
|
|
||||||
|
|
||||||
Middleware executes in this order on each request: **TierRateLimit → Sanitizer → CORS → Router**
|
|
||||||
|
|
||||||
### JWT Authentication
|
|
||||||
|
|
||||||
Source: `app/api/middleware/auth.py`
|
|
||||||
|
|
||||||
- FastAPI dependency `get_current_user` validates the `Bearer` JWT and extracts `user_id` and `email`.
|
|
||||||
- **Live tier lookup** — The current tier is fetched from the `subscriptions` table on every request (not cached in the JWT), so upgrades and downgrades take immediate effect.
|
|
||||||
- Falls back to `free` when no subscription row exists.
|
|
||||||
- Raises `401 Unauthorized` on invalid or expired tokens.
|
|
||||||
- **Exempt paths:** `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
|
||||||
|
|
||||||
### Tier-Based Rate Limiter
|
|
||||||
|
|
||||||
Source: `app/api/middleware/rate_limit.py`
|
|
||||||
|
|
||||||
- `TierRateLimitMiddleware` — Sliding-window in-process rate limiter (no Redis dependency).
|
|
||||||
- Per-user 60-second window sized by subscription tier:
|
|
||||||
|
|
||||||
| Tier | Requests / Minute |
|
|
||||||
|---|---|
|
|
||||||
| Free | 20 |
|
|
||||||
| Pro | 60 |
|
|
||||||
| Power | 120 |
|
|
||||||
| Team | 200 |
|
|
||||||
|
|
||||||
- Returns `429 Too Many Requests` with a `Retry-After` header when the limit is exceeded.
|
|
||||||
- **Exempt paths:** register, login, webhook, health
|
|
||||||
|
|
||||||
### Response Sanitizer
|
|
||||||
|
|
||||||
Source: `app/api/middleware/sanitizer.py`
|
|
||||||
|
|
||||||
- Runs only on `/api/v1/chat` endpoints.
|
|
||||||
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
|
|
||||||
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
|
|
||||||
- Logs sanitization events as `WARNING`.
|
|
||||||
- Binary responses (storage, backup) are never touched.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Storage Layer
|
|
||||||
|
|
||||||
### Blob Store
|
|
||||||
|
|
||||||
Source: `app/storage/blob_store.py`
|
|
||||||
|
|
||||||
- S3-backed storage for E2E encrypted blobs.
|
|
||||||
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
|
|
||||||
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
|
|
||||||
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
|
|
||||||
- The backend **never inspects or decrypts blob content**.
|
|
||||||
|
|
||||||
### Vector Store
|
|
||||||
|
|
||||||
Source: `app/storage/vector_store.py`
|
|
||||||
|
|
||||||
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
|
|
||||||
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
|
|
||||||
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
|
|
||||||
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
|
|
||||||
- Methods: `upsert()`, `search()`, `delete()`
|
|
||||||
|
|
||||||
### Encryption Utilities
|
|
||||||
|
|
||||||
Source: `app/storage/encryption.py`
|
|
||||||
|
|
||||||
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
|
|
||||||
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
|
|
||||||
- **No decryption key ever reaches the backend.**
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Billing & Tiers
|
|
||||||
|
|
||||||
Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
|
|
||||||
|
|
||||||
### Feature Matrix
|
|
||||||
|
|
||||||
| Feature | Free | Pro | Power | Team |
|
|
||||||
|---|---|---|---|---|
|
|
||||||
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
|
|
||||||
| Batch Active | 2 | 10 | Unlimited | Unlimited |
|
|
||||||
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
|
||||||
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
|
||||||
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
|
|
||||||
| Batch Builder | — | — | ✓ | ✓ |
|
|
||||||
| Plugin Marketplace | — | — | ✓ | ✓ |
|
|
||||||
| SSO | — | — | — | ✓ |
|
|
||||||
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
|
|
||||||
|
|
||||||
### Stripe Integration
|
|
||||||
|
|
||||||
- **Checkout** — `create_checkout_session(user_id, tier)` creates a Stripe Checkout session. Returns a stub URL when Stripe is not configured.
|
|
||||||
- **Webhooks** — Handles `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, and `invoice.payment_failed`.
|
|
||||||
- **Subscription management** — `get_subscription()` returns the current subscription record; `cancel_subscription()` cancels via the Stripe API and reverts the user to the free tier.
|
|
||||||
- **Price IDs:** `price_pro_monthly`, `price_power_monthly`, `price_team_monthly`
|
|
||||||
|
|
||||||
### Tier Manager
|
|
||||||
|
|
||||||
- `get_tier(user_id)` — Returns the user's current billing tier.
|
|
||||||
- `check_feature(tier, feature)` — Boolean feature gate check.
|
|
||||||
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
|
|
||||||
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Plugin Marketplace
|
|
||||||
|
|
||||||
Source: `app/marketplace/`
|
|
||||||
|
|
||||||
### Plugin Registry
|
|
||||||
|
|
||||||
- PostgreSQL-backed catalog of submitted and approved plugins.
|
|
||||||
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
|
|
||||||
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
|
|
||||||
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
|
|
||||||
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
|
|
||||||
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
|
|
||||||
|
|
||||||
### Review Queue
|
|
||||||
|
|
||||||
- Automated security checklist before human review:
|
|
||||||
- Plugin ID must match `^[a-z0-9-]+$`
|
|
||||||
- Permissions must be from the allowed set only
|
|
||||||
- No binary blobs in the manifest
|
|
||||||
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:timelines`, `write:timelines`, `read:calendar`, `write:calendar`
|
|
||||||
- `get_pending(db)` — Lists plugins awaiting review.
|
|
||||||
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
|
|
||||||
|
|
||||||
### Revenue Sharing
|
|
||||||
|
|
||||||
- **70% developer / 30% platform** split on all paid plugin sales.
|
|
||||||
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
|
|
||||||
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
|
|
||||||
- Gracefully stubs transfers when Stripe is not configured.
|
|
||||||
|
|
||||||
### Seed Plugins
|
|
||||||
|
|
||||||
| Plugin | Category | Price |
|
|
||||||
|---|---|---|
|
|
||||||
| GitHub Sync | Productivity | Free |
|
|
||||||
| Slack Notifier | Communication | €4.99 |
|
|
||||||
| Time Tracker | Productivity | €9.99 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
### Running Tests
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run all tests
|
|
||||||
pytest
|
|
||||||
|
|
||||||
# Run a specific test file
|
|
||||||
pytest tests/test_auth.py
|
|
||||||
|
|
||||||
# Run with verbose output
|
|
||||||
pytest -v
|
|
||||||
```
|
|
||||||
|
|
||||||
### Test Infrastructure
|
|
||||||
|
|
||||||
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
|
|
||||||
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
|
|
||||||
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
|
|
||||||
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
|
|
||||||
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
|
|
||||||
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
|
|
||||||
- **No external dependencies** — all tests run fully offline.
|
|
||||||
|
|
||||||
### Test Coverage
|
|
||||||
|
|
||||||
| File | Coverage |
|
|
||||||
|---|---|
|
|
||||||
| `test_auth.py` | Register, login, token access, refresh, expiration |
|
|
||||||
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
|
|
||||||
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
|
|
||||||
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
|
|
||||||
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
|
|
||||||
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
|
|
||||||
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
|
|
||||||
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
|
|
||||||
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── alembic.ini # Alembic configuration
|
|
||||||
├── BACKEND_PLAN.md # Architecture & design decisions
|
|
||||||
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
|
|
||||||
├── Dockerfile # Multi-stage production build
|
|
||||||
├── requirements.txt # Python dependencies
|
|
||||||
│
|
|
||||||
├── alembic/ # Database migrations
|
|
||||||
│ ├── env.py # Alembic environment config
|
|
||||||
│ ├── script.py.mako # Migration template
|
|
||||||
│ └── versions/
|
|
||||||
│ ├── 001_initial_schema.py # Tables, indexes, FKs
|
|
||||||
│ └── 002_seed_plugins.py # Seed marketplace plugins
|
|
||||||
│
|
|
||||||
├── app/ # Application source
|
|
||||||
│ ├── main.py # FastAPI app factory, middleware, routes
|
|
||||||
│ ├── db.py # Async SQLAlchemy engine & session
|
|
||||||
│ ├── models.py # SQLAlchemy ORM models (9 tables)
|
|
||||||
│ ├── schemas.py # Pydantic request/response schemas
|
|
||||||
│ │
|
|
||||||
│ ├── config/
|
|
||||||
│ │ └── settings.py # Pydantic Settings (env vars)
|
|
||||||
│ │
|
|
||||||
│ ├── agents/ # LLM-powered domain agents
|
|
||||||
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
|
|
||||||
│ │ ├── project_agent.py # Project lifecycle (6 tools)
|
|
||||||
│ │ ├── timeline_agent.py # Milestones (4 tools)
|
|
||||||
│ │ └── note_agent.py # Markdown notes (5 tools)
|
|
||||||
│ │
|
|
||||||
│ ├── core/ # Orchestration engine
|
|
||||||
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
|
||||||
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
|
|
||||||
│ │ ├── orchestrator.py # Intent classification & routing
|
|
||||||
│ │ └── execution_plan.py # Plan builder, templates, cache
|
|
||||||
│ │
|
|
||||||
│ ├── api/ # HTTP layer
|
|
||||||
│ │ ├── deps.py # Shared FastAPI dependencies
|
|
||||||
│ │ ├── middleware/
|
|
||||||
│ │ │ ├── auth.py # JWT validation, live tier lookup
|
|
||||||
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
|
|
||||||
│ │ │ └── sanitizer.py # Prompt IP leak protection
|
|
||||||
│ │ └── routes/
|
|
||||||
│ │ ├── auth.py # Register, login, refresh, me
|
|
||||||
│ │ ├── chat.py # Chat + WebSocket streaming
|
|
||||||
│ │ ├── plans.py # Execution plan playbooks
|
|
||||||
│ │ ├── storage.py # E2E encrypted record CRUD
|
|
||||||
│ │ ├── vectors.py # Vector upsert, search, delete
|
|
||||||
│ │ ├── backup.py # Encrypted backup management
|
|
||||||
│ │ ├── plugins.py # Marketplace browse & install
|
|
||||||
│ │ └── billing.py # Stripe checkout & webhooks
|
|
||||||
│ │
|
|
||||||
│ ├── storage/ # Storage backends
|
|
||||||
│ │ ├── blob_store.py # S3 blob storage
|
|
||||||
│ │ ├── vector_store.py # Pinecone / Qdrant vector store
|
|
||||||
│ │ └── encryption.py # Checksum verification utilities
|
|
||||||
│ │
|
|
||||||
│ ├── billing/ # Subscription management
|
|
||||||
│ │ ├── stripe_service.py # Stripe API integration
|
|
||||||
│ │ └── tier_manager.py # Feature matrix & quota enforcement
|
|
||||||
│ │
|
|
||||||
│ └── marketplace/ # Plugin ecosystem
|
|
||||||
│ ├── plugin_registry.py # Catalog CRUD & search
|
|
||||||
│ ├── plugin_review.py # Security checklist & review queue
|
|
||||||
│ └── revenue_share.py # 70/30 split & Stripe Connect
|
|
||||||
│
|
|
||||||
└── tests/ # Test suite
|
|
||||||
├── conftest.py # Fixtures: DB, S3, auth, seeds
|
|
||||||
├── test_auth.py
|
|
||||||
├── test_orchestrator.py
|
|
||||||
├── test_agents.py
|
|
||||||
├── test_storage.py
|
|
||||||
├── test_backup.py
|
|
||||||
├── test_plugins.py
|
|
||||||
├── test_agent_registry.py
|
|
||||||
├── test_execution_plan.py
|
|
||||||
└── test_middleware.py
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
*To be determined.*
|
|
||||||
|
|||||||
@@ -1,353 +0,0 @@
|
|||||||
# V3 Migration Plan — Multi-Agent AI Productivity App
|
|
||||||
|
|
||||||
> Incremental migration from current architecture to v3.
|
|
||||||
> Each step is self-contained, testable, and backwards-compatible.
|
|
||||||
> No BYOK — server manages all LLM keys.
|
|
||||||
> Memory encryption: server-side per-user Fernet key (Option A).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## General Rules
|
|
||||||
|
|
||||||
**Code Cleanup**: As you implement each step, remove any code that becomes unused or obsolete. This includes:
|
|
||||||
- Old functions/methods that are superseded by new ones
|
|
||||||
- Deprecated imports or modules
|
|
||||||
- Dead code paths
|
|
||||||
- Old test files no longer needed
|
|
||||||
|
|
||||||
This keeps the codebase clean and prevents confusion. When removing code, note it in the commit message if significant.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Decisions Log
|
|
||||||
|
|
||||||
| Topic | Decision |
|
|
||||||
|---|---|
|
|
||||||
| WS topology | Single multiplexed socket (merge chat into device WS) |
|
|
||||||
| LLM keys | Server-managed only, no user key passthrough |
|
|
||||||
| Memory encryption | Per-user server-generated Fernet key, encrypted at rest, decrypted in-memory |
|
|
||||||
| device_manager | Already multi-user correct (keyed by user_id), no structural change |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 1 — WS Frame Protocol (schemas.py)
|
|
||||||
|
|
||||||
**Goal**: Define the v3 frame vocabulary so all subsequent steps can import it.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/schemas.py` — Add to `WsFrameType` enum:
|
|
||||||
- `home_request`, `floating_request`
|
|
||||||
- `stream_start`, `stream_text`, `stream_block`, `stream_end`
|
|
||||||
- `floating_domain`
|
|
||||||
- `data_request`, `data_response`, `mutation`
|
|
||||||
- Add Pydantic models:
|
|
||||||
- `WsHomeRequest(type, message, conversation_history?)`
|
|
||||||
- `WsFloatingRequest(type, message, scope: {type, id?})`
|
|
||||||
- `WsStreamStart(type, request_id)`
|
|
||||||
- `WsStreamText(type, request_id, chunk)`
|
|
||||||
- `WsStreamBlock(type, request_id, block_type, data)`
|
|
||||||
- `WsStreamEnd(type, request_id, mutations?)`
|
|
||||||
- `WsFloatingDomain(type, request_id, domain)`
|
|
||||||
- Keep all existing frame types (backward compat).
|
|
||||||
|
|
||||||
**Files touched**: `app/schemas.py`
|
|
||||||
|
|
||||||
**Test**: Unit test that validates each new model serializes/deserializes correctly.
|
|
||||||
```
|
|
||||||
pytest tests/test_schemas_v3.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 1 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-1: add v3 ws frame protocol (schemas.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 2 — Agent Streaming + Tool Result Capture (agent_registry.py, agents/)
|
|
||||||
|
|
||||||
**Goal**: Agents can stream LLM tokens and expose structured tool results.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/core/agent_registry.py`:
|
|
||||||
- Add `_tool_loop_stream()` to `ChatAgent` — same logic as `_tool_loop()` but the **final** LLM call (when no more tool calls) uses `llm.astream()` and yields tokens.
|
|
||||||
- Add `self.tool_results: list[dict]` attribute to `ChatAgent.__init__()`.
|
|
||||||
- In both `_tool_loop` and `_tool_loop_stream`, capture raw `execute_on_client` results when tools run (store in `self.tool_results`).
|
|
||||||
- `app/agents/*.py` — Each agent's tools already return text summaries. No change to tools. The raw data capture happens at the `_tool_loop` level by intercepting `ToolMessage` content that comes from `execute_on_client`.
|
|
||||||
|
|
||||||
**Files touched**: `app/core/agent_registry.py`
|
|
||||||
|
|
||||||
**Test**: Unit test with mocked LLM that verifies `_tool_loop_stream()` yields tokens and `agent.tool_results` contains structured data after a tool call.
|
|
||||||
```
|
|
||||||
pytest tests/test_agent_streaming.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 2 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-2: add agent streaming and tool result capture (agent_registry.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 3 — Router Refactor (orchestrator.py)
|
|
||||||
|
|
||||||
**Goal**: Orchestrator returns agent name alongside execution, supports streaming.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/core/orchestrator.py`:
|
|
||||||
- Add `orchestrate_v3(user_id, message, context, mode)` that:
|
|
||||||
1. Calls `classify_intent()` (unchanged) -> `agent_name`
|
|
||||||
2. Instantiates agent via registry
|
|
||||||
3. Returns `(agent_name, agent_instance)` — caller drives execution
|
|
||||||
- Add `orchestrate_v3_stream(user_id, message, context)` -> `AsyncGenerator` that:
|
|
||||||
1. Calls `classify_intent()` -> `agent_name`
|
|
||||||
2. Calls `agent.handle_stream()` (uses `_tool_loop_stream`)
|
|
||||||
3. Yields `(agent_name, token)` tuples — first yield includes agent name for domain detection
|
|
||||||
- Keep `orchestrate()` and `orchestrate_stream()` unchanged (backward compat for POST /chat).
|
|
||||||
|
|
||||||
**Files touched**: `app/core/orchestrator.py`
|
|
||||||
|
|
||||||
**Test**: Unit test with mocked LLM and mocked registry that verifies `orchestrate_v3_stream` yields `(agent_name, token)` pairs.
|
|
||||||
```
|
|
||||||
pytest tests/test_orchestrator_v3.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 3 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-3: add router refactor with streaming support (orchestrator.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 4 — Output Formatting Layer (NEW: output_formatter.py)
|
|
||||||
|
|
||||||
**Goal**: Home and Floating responses diverge at this layer only.
|
|
||||||
|
|
||||||
### Block Types (from Electron app components)
|
|
||||||
|
|
||||||
The LLM outputs a JSON block stream. Each block has a `type` field that maps to
|
|
||||||
an Electron renderer component. The server validates and forwards these blocks.
|
|
||||||
|
|
||||||
**Text block** — streamed immediately, word-by-word:
|
|
||||||
```json
|
|
||||||
{ "type": "text", "content": "Here's your task summary..." }
|
|
||||||
```
|
|
||||||
|
|
||||||
**Chart blocks** — buffered until complete, validated, sent as `stream_block`.
|
|
||||||
Chart types match shadcn/ui Recharts wrappers used in the Electron app:
|
|
||||||
```json
|
|
||||||
{ "type": "chart", "chartType": "<type>", "title": "...", "data": [...], "config": {...} }
|
|
||||||
```
|
|
||||||
Supported `chartType` values:
|
|
||||||
- `area` — Area chart (shadcn AreaChart)
|
|
||||||
- `bar` — Bar chart (shadcn BarChart)
|
|
||||||
- `line` — Line chart (shadcn LineChart)
|
|
||||||
- `pie` — Pie chart (shadcn PieChart)
|
|
||||||
- `radar` — Radar chart (shadcn RadarChart)
|
|
||||||
- `radial` — Radial/gauge chart (shadcn RadialChart)
|
|
||||||
|
|
||||||
`data` is an array of objects with keys matching the chart's dataKey config.
|
|
||||||
`config` follows the shadcn ChartConfig format: `{ [dataKey]: { label, color } }`.
|
|
||||||
|
|
||||||
**Entity blocks** — server serializes from `agent.tool_results` (not LLM-generated data):
|
|
||||||
```json
|
|
||||||
{ "type": "entity_ref", "entity": "task" }
|
|
||||||
```
|
|
||||||
The server resolves this by looking up the structured data from the agent's
|
|
||||||
tool call results and emitting a `stream_block` with the full entity data.
|
|
||||||
|
|
||||||
Supported entity types (matching Electron component types):
|
|
||||||
- `task` — TaskRow component (`TaskItem`: id, title, status, priority, assignee, dueDate, projectId, ...)
|
|
||||||
- `project` — Project card (id, name, clientId, status)
|
|
||||||
- `note` — Note card (id, title, createdAt, projectId)
|
|
||||||
- `timeline` — Timeline card (GanttTimeline: id, title, date, projectId, isAiSuggested, isApproved)
|
|
||||||
|
|
||||||
**Table block** — buffered, validated:
|
|
||||||
```json
|
|
||||||
{ "type": "table", "headers": ["Col1", "Col2"], "rows": [["val1", "val2"]] }
|
|
||||||
```
|
|
||||||
|
|
||||||
**Timeline block** — buffered, validated (renders via GanttChart component):
|
|
||||||
```json
|
|
||||||
{ "type": "timeline", "timelines": [{ "id": "...", "title": "...", "date": 1234567890 }] }
|
|
||||||
```
|
|
||||||
|
|
||||||
### Changes
|
|
||||||
|
|
||||||
- `app/core/output_formatter.py` (new file):
|
|
||||||
- `HomeFormatter`:
|
|
||||||
- Receives token stream from orchestrator
|
|
||||||
- Accumulates tokens into a JSON-aware buffer
|
|
||||||
- Detects block boundaries by `type` field:
|
|
||||||
- `text` -> yields `WsStreamText` immediately (streams content word-by-word)
|
|
||||||
- `chart` -> buffers until JSON complete, validates `chartType` against allowed set, yields `WsStreamBlock`
|
|
||||||
- `entity_ref` -> looks up data from `agent.tool_results`, serializes full entity, yields `WsStreamBlock`
|
|
||||||
- `table` -> buffers, validates headers/rows structure, yields `WsStreamBlock`
|
|
||||||
- `timeline` -> buffers, validates timeline objects, yields `WsStreamBlock`
|
|
||||||
- Invalid blocks are logged and skipped (never crash the stream)
|
|
||||||
- `FloatingFormatter`:
|
|
||||||
- Receives `agent_name` from orchestrator
|
|
||||||
- Maps agent name to domain (deterministic, by code — no LLM):
|
|
||||||
- `task_agent` -> `"tasks"`
|
|
||||||
- `timeline_agent` -> `"timelines"`
|
|
||||||
- `note_agent` -> `"notes"`
|
|
||||||
- `project_agent` -> `"projects"`
|
|
||||||
- Yields `WsFloatingDomain` immediately
|
|
||||||
- Then yields `WsStreamText` for all tokens (text-only, no blocks)
|
|
||||||
|
|
||||||
**Files touched**: `app/core/output_formatter.py` (new)
|
|
||||||
|
|
||||||
**Test**: Unit test that feeds a mock token stream through each formatter and asserts correct frame output sequence.
|
|
||||||
```
|
|
||||||
pytest tests/test_output_formatter.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 4 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-4: add output formatting layer (output_formatter.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 5 — Unified WS Handler (device_ws.py, chat.py, main.py)
|
|
||||||
|
|
||||||
**Goal**: Single multiplexed WebSocket handles device frames + Home/Floating chat.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/api/routes/device_ws.py`:
|
|
||||||
- Extend `_message_loop` dispatch to handle `home_request` and `floating_request`:
|
|
||||||
- On `home_request`: set `ws_context` executor, call `orchestrate_v3_stream`, pipe through `HomeFormatter`, send frames back on same socket.
|
|
||||||
- On `floating_request`: same, but pipe through `FloatingFormatter`.
|
|
||||||
- Wrap both in try/finally to clear `ws_context`.
|
|
||||||
- Each request gets a `request_id` (UUID) for frame correlation.
|
|
||||||
- Concurrent requests from same client are supported (each runs as an async task).
|
|
||||||
- `app/api/routes/chat.py`:
|
|
||||||
- Remove `chat_stream` WS endpoint and any related helper functions that were only used by it.
|
|
||||||
- Keep `POST /chat` endpoint unchanged (REST fallback).
|
|
||||||
- Clean up any unused imports.
|
|
||||||
- `app/main.py`:
|
|
||||||
- No change needed (device_ws router already registered).
|
|
||||||
|
|
||||||
**Files touched**: `app/api/routes/device_ws.py`, `app/api/routes/chat.py`, `app/main.py`
|
|
||||||
|
|
||||||
**Test**: Integration test with a WebSocket test client that:
|
|
||||||
1. Connects to `/api/v1/ws/device`
|
|
||||||
2. Sends `device_hello`
|
|
||||||
3. Sends `home_request` -> receives `stream_start`, `stream_text`*, `stream_end`
|
|
||||||
4. Sends `floating_request` -> receives `floating_domain`, `stream_text`*, `stream_end`
|
|
||||||
5. Verifies `tool_call`/`tool_result` round-trip still works during chat
|
|
||||||
```
|
|
||||||
pytest tests/test_ws_unified.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 5 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-5: unify ws handler (device_ws.py, chat.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 6 — Memory Models + Migration (models.py, alembic)
|
|
||||||
|
|
||||||
**Goal**: Database tables for 4-tier memory, with per-user encryption key.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/models.py`:
|
|
||||||
- Add `encryption_key` column to `User` model (Fernet key, generated on registration).
|
|
||||||
- Add `MemoryCore` model: `id, user_id, key, value_encrypted, updated_at`
|
|
||||||
- Add `MemoryAssociative` model: `id, user_id, content_encrypted, embedding (Vector(1536)), entity_type, entity_id, updated_at`
|
|
||||||
- Add `MemoryEpisodic` model: `id, user_id, summary_encrypted, session_id, created_at`
|
|
||||||
- Add `MemoryProactive` model: `id, user_id, pattern_encrypted, confidence, source, created_at`
|
|
||||||
- `alembic/versions/` — New migration adding the 4 memory tables + user encryption_key column.
|
|
||||||
- `app/api/routes/auth.py` — On user registration, generate and store a Fernet key.
|
|
||||||
|
|
||||||
**Files touched**: `app/models.py`, `alembic/versions/xxx_add_memory_tables.py`, `app/api/routes/auth.py`
|
|
||||||
|
|
||||||
**Test**: Run migration up/down, verify tables exist with correct columns.
|
|
||||||
```
|
|
||||||
alembic upgrade head && alembic downgrade -1 && alembic upgrade head
|
|
||||||
pytest tests/test_memory_models.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 6 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-6: add memory models and migration (models.py, alembic)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 7 — Memory Middleware (NEW: memory_middleware.py)
|
|
||||||
|
|
||||||
**Goal**: Enrich every Router call with memory context, store interactions after.
|
|
||||||
|
|
||||||
**Changes**:
|
|
||||||
- `app/core/memory_middleware.py` (new file):
|
|
||||||
- `MemoryMiddleware` class with:
|
|
||||||
- `enrich_context(user_id, message) -> dict` (pre-LLM):
|
|
||||||
1. Load core memory (user prefs) — always injected
|
|
||||||
2. Embed `message`, search `MemoryAssociative` via pgvector — top-k relevant
|
|
||||||
3. Fetch recent `MemoryEpisodic` entries — last N sessions
|
|
||||||
4. Fetch active `MemoryProactive` patterns — above confidence threshold
|
|
||||||
5. Return merged context dict
|
|
||||||
- `store_episode(user_id, session_id, message, response)` (post-LLM):
|
|
||||||
1. Summarize interaction (short LLM call or heuristic)
|
|
||||||
2. Encrypt and store in `MemoryEpisodic`
|
|
||||||
3. Embed interaction, encrypt and upsert in `MemoryAssociative`
|
|
||||||
- `update_core(user_id, key, value)` — explicit preference update
|
|
||||||
- All read/write operations encrypt/decrypt using the user's Fernet key from `User.encryption_key`
|
|
||||||
- `app/api/routes/device_ws.py` — Update `home_request` and `floating_request` handlers:
|
|
||||||
- Before orchestrator: `enriched = await memory.enrich_context(user_id, message)`
|
|
||||||
- After response complete: `await memory.store_episode(user_id, ...)`
|
|
||||||
|
|
||||||
**Files touched**: `app/core/memory_middleware.py` (new), `app/api/routes/device_ws.py`
|
|
||||||
|
|
||||||
**Test**: Unit test with seeded memory rows that verifies:
|
|
||||||
1. `enrich_context` returns core prefs + associative matches + episodic summaries
|
|
||||||
2. `store_episode` creates encrypted rows that can be decrypted with the user's key
|
|
||||||
3. End-to-end WS test: send `home_request`, verify memory enrichment is passed to orchestrator
|
|
||||||
```
|
|
||||||
pytest tests/test_memory_middleware.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Status**:
|
|
||||||
- [x] Step 7 complete
|
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
|
||||||
```
|
|
||||||
git commit -m "step-7: add memory middleware (memory_middleware.py, device_ws.py)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Summary
|
|
||||||
|
|
||||||
| Step | Component | Effort | Depends On |
|
|
||||||
|------|-----------|--------|------------|
|
|
||||||
| 1 | WS Frame Protocol | Low | — |
|
|
||||||
| 2 | Agent Streaming | Medium | Step 1 |
|
|
||||||
| 3 | Router Refactor | Medium | Step 2 |
|
|
||||||
| 4 | Output Formatter | High | Steps 1, 3 |
|
|
||||||
| 5 | Unified WS Handler | High | Steps 1–4 |
|
|
||||||
| 6 | Memory Models | Medium | — |
|
|
||||||
| 7 | Memory Middleware | High | Steps 5, 6 |
|
|
||||||
|
|
||||||
Steps 1–5 form the streaming pipeline. Steps 6–7 form the memory system.
|
|
||||||
Step 6 can run in parallel with Steps 2–4 (no dependencies).
|
|
||||||
@@ -16,7 +16,7 @@ import re
|
|||||||
from logging.config import fileConfig
|
from logging.config import fileConfig
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from sqlalchemy import engine_from_config, pool
|
from sqlalchemy import pool
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
# Alembic Config object (gives access to alembic.ini values).
|
# Alembic Config object (gives access to alembic.ini values).
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
|
"""Initial schema: users, refresh_tokens, subscriptions.
|
||||||
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
|
|
||||||
|
|
||||||
Revision ID: 001
|
Revision ID: 001
|
||||||
Revises:
|
Revises:
|
||||||
@@ -28,18 +27,6 @@ def upgrade() -> None:
|
|||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
END $$;
|
END $$;
|
||||||
""")
|
""")
|
||||||
op.execute("""
|
|
||||||
DO $$ BEGIN
|
|
||||||
CREATE TYPE plugin_status AS ENUM ('pending_review', 'approved', 'rejected');
|
|
||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
|
||||||
END $$;
|
|
||||||
""")
|
|
||||||
op.execute("""
|
|
||||||
DO $$ BEGIN
|
|
||||||
CREATE TYPE review_decision AS ENUM ('approved', 'rejected');
|
|
||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
|
||||||
END $$;
|
|
||||||
""")
|
|
||||||
|
|
||||||
# ── users ─────────────────────────────────────────────────────────────
|
# ── users ─────────────────────────────────────────────────────────────
|
||||||
op.create_table(
|
op.create_table(
|
||||||
@@ -88,122 +75,10 @@ def upgrade() -> None:
|
|||||||
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||||
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||||
|
|
||||||
# ── storage_records ───────────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"storage_records",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("table_name", sa.String(100), nullable=False),
|
|
||||||
sa.Column("s3_key", sa.String(500), nullable=False),
|
|
||||||
sa.Column("checksum", sa.String(64), nullable=False),
|
|
||||||
sa.Column("size_bytes", sa.Integer, nullable=False),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
|
|
||||||
|
|
||||||
# ── backup_metadata ───────────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"backup_metadata",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("s3_key", sa.String(500), nullable=False),
|
|
||||||
sa.Column("version", sa.Integer, nullable=False),
|
|
||||||
sa.Column("timestamp", sa.BigInteger, nullable=False),
|
|
||||||
sa.Column("checksum", sa.String(64), nullable=False),
|
|
||||||
sa.Column("size_bytes", sa.Integer, nullable=False),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
|
|
||||||
|
|
||||||
# ── plugins ───────────────────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"plugins",
|
|
||||||
sa.Column("id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("name", sa.String(255), nullable=False),
|
|
||||||
sa.Column("description", sa.Text, nullable=False, server_default=""),
|
|
||||||
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
|
|
||||||
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
|
|
||||||
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
|
|
||||||
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
|
||||||
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
|
||||||
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("status", postgresql.ENUM("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
|
|
||||||
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
|
||||||
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
|
||||||
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
|
||||||
sa.Column("rejection_reason", sa.Text, nullable=True),
|
|
||||||
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── plugin_installations ──────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"plugin_installations",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
|
|
||||||
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
|
|
||||||
|
|
||||||
# ── plugin_reviews ────────────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"plugin_reviews",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
|
||||||
sa.Column("decision", postgresql.ENUM("approved", "rejected", name="review_decision", create_type=False), nullable=False),
|
|
||||||
sa.Column("notes", sa.Text, nullable=True),
|
|
||||||
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
|
||||||
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
|
|
||||||
|
|
||||||
# ── revenue_events ────────────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"revenue_events",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
|
|
||||||
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
|
|
||||||
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
|
|
||||||
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
|
|
||||||
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.drop_table("revenue_events")
|
|
||||||
op.drop_table("plugin_reviews")
|
|
||||||
op.drop_table("plugin_installations")
|
|
||||||
op.drop_table("plugins")
|
|
||||||
op.drop_table("backup_metadata")
|
|
||||||
op.drop_table("storage_records")
|
|
||||||
op.drop_table("subscriptions")
|
op.drop_table("subscriptions")
|
||||||
op.drop_table("refresh_tokens")
|
op.drop_table("refresh_tokens")
|
||||||
op.drop_table("users")
|
op.drop_table("users")
|
||||||
|
|
||||||
op.execute("DROP TYPE IF EXISTS review_decision")
|
|
||||||
op.execute("DROP TYPE IF EXISTS plugin_status")
|
|
||||||
op.execute("DROP TYPE IF EXISTS billing_tier")
|
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||||
|
|||||||
@@ -1,92 +0,0 @@
|
|||||||
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
|
|
||||||
|
|
||||||
Revision ID: 002
|
|
||||||
Revises: 001
|
|
||||||
Create Date: 2026-03-03
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
revision: str = "002"
|
|
||||||
down_revision: Union[str, None] = "001"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
_SEED_PLUGINS = [
|
|
||||||
{
|
|
||||||
"id": "plugin-github-sync",
|
|
||||||
"name": "GitHub Sync",
|
|
||||||
"description": "Sync tasks with GitHub Issues and pull requests.",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"author_name": "Adiuva",
|
|
||||||
"category": "productivity",
|
|
||||||
"price_cents": 0,
|
|
||||||
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
|
||||||
"status": "approved",
|
|
||||||
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "plugin-slack-notify",
|
|
||||||
"name": "Slack Notifier",
|
|
||||||
"description": "Post task and timeline updates to Slack channels.",
|
|
||||||
"version": "1.2.0",
|
|
||||||
"author_name": "Adiuva",
|
|
||||||
"category": "communication",
|
|
||||||
"price_cents": 499,
|
|
||||||
"permissions": json.dumps(["read:tasks", "read:timelines"]),
|
|
||||||
"status": "approved",
|
|
||||||
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "plugin-time-tracker",
|
|
||||||
"name": "Time Tracker",
|
|
||||||
"description": "Track time spent on tasks with automatic reporting.",
|
|
||||||
"version": "0.9.1",
|
|
||||||
"author_name": "Third Party",
|
|
||||||
"category": "productivity",
|
|
||||||
"price_cents": 999,
|
|
||||||
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
|
||||||
"status": "approved",
|
|
||||||
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
plugins = sa.table(
|
|
||||||
"plugins",
|
|
||||||
sa.column("id", sa.String),
|
|
||||||
sa.column("name", sa.String),
|
|
||||||
sa.column("description", sa.Text),
|
|
||||||
sa.column("version", sa.String),
|
|
||||||
sa.column("author_name", sa.String),
|
|
||||||
sa.column("category", sa.String),
|
|
||||||
sa.column("price_cents", sa.Integer),
|
|
||||||
sa.column("permissions", sa.Text),
|
|
||||||
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
|
|
||||||
sa.column("s3_package_key", sa.String),
|
|
||||||
sa.column("install_count", sa.Integer),
|
|
||||||
sa.column("avg_rating", sa.Float),
|
|
||||||
)
|
|
||||||
op.bulk_insert(plugins, _SEED_PLUGINS)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.execute(
|
|
||||||
"DELETE FROM plugins WHERE id IN ("
|
|
||||||
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
|
|
||||||
")"
|
|
||||||
)
|
|
||||||
@@ -14,7 +14,7 @@ from alembic import op
|
|||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
revision: str = "003"
|
revision: str = "003"
|
||||||
down_revision: Union[str, None] = "002"
|
down_revision: Union[str, None] = "001"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|||||||
54
alembic/versions/005_associative_pgvector.py
Normal file
54
alembic/versions/005_associative_pgvector.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Phase 1 — confirm pgvector activation on memory_associative.
|
||||||
|
|
||||||
|
Migration 004 created the embedding column as vector(1536) and added the
|
||||||
|
IVFFlat index. This migration is the Phase-1 checkpoint:
|
||||||
|
1. Ensures the pgvector extension is enabled (idempotent).
|
||||||
|
2. Ensures the canonical Phase-1 IVFFlat index exists under the name
|
||||||
|
memory_associative_embedding_idx (creates it only if absent).
|
||||||
|
|
||||||
|
Revision ID: 005
|
||||||
|
Revises: 9a1f2d0b6c7e
|
||||||
|
Create Date: 2026-04-15
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "005"
|
||||||
|
down_revision: Union[str, None] = "e04100e88ace"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Ensure pgvector extension is enabled (also done in 004, idempotent).
|
||||||
|
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||||
|
|
||||||
|
# Ensure the canonical Phase-1 IVFFlat index exists.
|
||||||
|
# 004 may have created ix_memory_associative_embedding; this adds the
|
||||||
|
# Phase-1 name memory_associative_embedding_idx if it is missing.
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_indexes
|
||||||
|
WHERE tablename = 'memory_associative'
|
||||||
|
AND indexname = 'memory_associative_embedding_idx'
|
||||||
|
) THEN
|
||||||
|
CREATE INDEX memory_associative_embedding_idx
|
||||||
|
ON memory_associative
|
||||||
|
USING ivfflat (embedding vector_cosine_ops)
|
||||||
|
WITH (lists = 100);
|
||||||
|
END IF;
|
||||||
|
END $$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("DROP INDEX IF EXISTS memory_associative_embedding_idx;")
|
||||||
74
alembic/versions/006_memory_relations.py
Normal file
74
alembic/versions/006_memory_relations.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Add memory_relations table (Phase 3 — relational tier).
|
||||||
|
|
||||||
|
Revision ID: 006
|
||||||
|
Revises: 1f5975a4f3f4
|
||||||
|
Create Date: 2026-04-16
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "006"
|
||||||
|
down_revision: Union[str, None] = "1f5975a4f3f4"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"memory_relations",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("subject_label", sa.String(128), nullable=False),
|
||||||
|
sa.Column("subject_type", sa.String(32), nullable=False),
|
||||||
|
sa.Column("predicate", sa.String(64), nullable=False),
|
||||||
|
sa.Column("object_label", sa.String(128), nullable=False),
|
||||||
|
sa.Column("object_type", sa.String(32), nullable=False),
|
||||||
|
sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"),
|
||||||
|
sa.Column(
|
||||||
|
"source_episode_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column("notes_encrypted", sa.LargeBinary, nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"memory_relations_user_subject_idx",
|
||||||
|
"memory_relations",
|
||||||
|
["user_id", "subject_label"],
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"memory_relations_user_predicate_idx",
|
||||||
|
"memory_relations",
|
||||||
|
["user_id", "predicate"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("memory_relations_user_predicate_idx", "memory_relations")
|
||||||
|
op.drop_index("memory_relations_user_subject_idx", "memory_relations")
|
||||||
|
op.drop_table("memory_relations")
|
||||||
38
alembic/versions/1f5975a4f3f4_add_extraction_queue.py
Normal file
38
alembic/versions/1f5975a4f3f4_add_extraction_queue.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""add extraction_queue
|
||||||
|
|
||||||
|
Revision ID: 1f5975a4f3f4
|
||||||
|
Revises: 005
|
||||||
|
Create Date: 2026-04-16 17:26:25.790870
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '1f5975a4f3f4'
|
||||||
|
down_revision: Union[str, None] = '005'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
'extraction_queue',
|
||||||
|
sa.Column('id', sa.Uuid(as_uuid=False), nullable=False),
|
||||||
|
sa.Column('user_id', sa.Uuid(as_uuid=False), nullable=False),
|
||||||
|
sa.Column('episode_id', sa.Uuid(as_uuid=False), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||||
|
sa.PrimaryKeyConstraint('id'),
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_extraction_queue_user_id'), 'extraction_queue', ['user_id'], unique=False)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index(op.f('ix_extraction_queue_user_id'), table_name='extraction_queue')
|
||||||
|
op.drop_table('extraction_queue')
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
"""Deprecate backend agent config tables.
|
||||||
|
|
||||||
|
The Electron client is now the source of truth for agent configuration
|
||||||
|
(directory, extract targets, batch interval, custom prompt). Backend keeps
|
||||||
|
billing checks and trigger/run logs only.
|
||||||
|
|
||||||
|
Revision ID: 9a1f2d0b6c7e
|
||||||
|
Revises: 818478c251dc
|
||||||
|
Create Date: 2026-03-16
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "9a1f2d0b6c7e"
|
||||||
|
down_revision: Union[str, None] = "818478c251dc"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = sa.inspect(bind)
|
||||||
|
existing = set(inspector.get_table_names())
|
||||||
|
|
||||||
|
if "cloud_agent_configs" in existing:
|
||||||
|
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||||
|
op.drop_table("cloud_agent_configs")
|
||||||
|
|
||||||
|
if "local_agent_configs" in existing:
|
||||||
|
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||||
|
op.drop_table("local_agent_configs")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"local_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("device_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"cloud_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"provider",
|
||||||
|
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||||
|
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
"""Restore agent config tables and add agent_config column.
|
||||||
|
|
||||||
|
9a1f2d0b6c7e dropped local_agent_configs and cloud_agent_configs, but both
|
||||||
|
ORM models are still active. This migration recreates them with agent_config
|
||||||
|
added to local_agent_configs.
|
||||||
|
|
||||||
|
Revision ID: a3b9c0d1e2f3
|
||||||
|
Revises: 9a1f2d0b6c7e
|
||||||
|
Create Date: 2026-04-07 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "a3b9c0d1e2f3"
|
||||||
|
down_revision: Union[str, None] = "9a1f2d0b6c7e"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Recreate enum types (idempotent — they may already exist from migration 003)
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = sa.inspect(bind)
|
||||||
|
existing = set(inspector.get_table_names())
|
||||||
|
|
||||||
|
# ── local_agent_configs (with agent_config column) ────────────────────
|
||||||
|
if "local_agent_configs" not in existing:
|
||||||
|
op.create_table(
|
||||||
|
"local_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("device_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("agent_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
# ── cloud_agent_configs ───────────────────────────────────────────────
|
||||||
|
if "cloud_agent_configs" not in existing:
|
||||||
|
op.create_table(
|
||||||
|
"cloud_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"provider",
|
||||||
|
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||||
|
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||||
|
op.drop_table("cloud_agent_configs")
|
||||||
|
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||||
|
op.drop_table("local_agent_configs")
|
||||||
56
alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py
Normal file
56
alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Add oauth_accounts table, nullable password_hash, avatar_url to users.
|
||||||
|
|
||||||
|
Revision ID: b4c0d1e2f3a4
|
||||||
|
Revises: a3b9c0d1e2f3
|
||||||
|
Create Date: 2026-04-10 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "b4c0d1e2f3a4"
|
||||||
|
down_revision: Union[str, None] = "a3b9c0d1e2f3"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── users: make password_hash nullable (social users have no password) ──
|
||||||
|
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=True)
|
||||||
|
|
||||||
|
# ── users: add avatar_url ─────────────────────────────────────────────
|
||||||
|
op.add_column("users", sa.Column("avatar_url", sa.String(2048), nullable=True))
|
||||||
|
|
||||||
|
# ── oauth_accounts ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"oauth_accounts",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("provider", sa.String(50), nullable=False),
|
||||||
|
sa.Column("provider_user_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("provider_email", sa.String(255), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts")
|
||||||
|
op.drop_table("oauth_accounts")
|
||||||
|
op.drop_column("users", "avatar_url")
|
||||||
|
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=False)
|
||||||
31
alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py
Normal file
31
alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Add onboarding_completed_at column to users table.
|
||||||
|
|
||||||
|
Revision ID: c5d1e2f3a4b5
|
||||||
|
Revises: b4c0d1e2f3a4
|
||||||
|
Create Date: 2026-04-11 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "c5d1e2f3a4b5"
|
||||||
|
down_revision: Union[str, None] = "b4c0d1e2f3a4"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"users",
|
||||||
|
sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("users", "onboarding_completed_at")
|
||||||
34
alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py
Normal file
34
alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""avatar_url_varchar_to_text
|
||||||
|
|
||||||
|
Revision ID: e04100e88ace
|
||||||
|
Revises: c5d1e2f3a4b5
|
||||||
|
Create Date: 2026-04-13 09:13:06.733674
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'e04100e88ace'
|
||||||
|
down_revision: Union[str, None] = 'c5d1e2f3a4b5'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.alter_column('users', 'avatar_url',
|
||||||
|
existing_type=sa.VARCHAR(length=2048),
|
||||||
|
type_=sa.Text(),
|
||||||
|
existing_nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.alter_column('users', 'avatar_url',
|
||||||
|
existing_type=sa.Text(),
|
||||||
|
type_=sa.VARCHAR(length=2048),
|
||||||
|
existing_nullable=True)
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Import all agent modules to trigger @registry.register decorators."""
|
"""Expose tool modules used by deep orchestrator-worker graphs."""
|
||||||
|
|
||||||
from app.agents import timeline_agent, note_agent, project_agent, task_agent
|
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
__all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]
|
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
|
|||||||
194
app/agents/filesystem_agent.py
Normal file
194
app/agents/filesystem_agent.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||||
|
|
||||||
|
These tools delegate to the Electron client via ``execute_on_client()`` using
|
||||||
|
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
|
||||||
|
handles actual disk I/O and responds with ``tool_result`` frames.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
# Max characters returned by read_file_content in journey (exploration) tools.
|
||||||
|
# The journey only needs to understand file structure, not full content.
|
||||||
|
_JOURNEY_READ_MAX_CHARS: int = 4000
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_path(path: str, base: str) -> str:
|
||||||
|
"""Resolve *path* against *base* when *path* is relative.
|
||||||
|
|
||||||
|
The LLM often passes ``"."`` meaning "the configured directory".
|
||||||
|
Without this, Electron resolves ``"."`` relative to its own CWD instead
|
||||||
|
of the user's chosen directory.
|
||||||
|
"""
|
||||||
|
if os.path.isabs(path):
|
||||||
|
return path
|
||||||
|
return str(Path(base) / path)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_directory(path: str) -> str:
|
||||||
|
"""List files and folders in a local directory on the user's device.
|
||||||
|
|
||||||
|
Returns a formatted listing of entries with name, type (file/directory),
|
||||||
|
and full path.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
return f"Directory '{path}' is empty or does not exist."
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
entry_type = entry.get("type", "unknown")
|
||||||
|
entry_name = entry.get("name", "")
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||||
|
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_file_content(path: str) -> str:
|
||||||
|
"""Read the text content of a local file on the user's device.
|
||||||
|
|
||||||
|
Returns the file content as a string. Large files may be truncated
|
||||||
|
by the Electron client.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
content: str = result.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return f"File '{path}' is empty or could not be read."
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_file_metadata(path: str) -> str:
|
||||||
|
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||||
|
|
||||||
|
Returns a formatted summary of the file's metadata.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_file_metadata",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
size = result.get("size", "unknown")
|
||||||
|
created = result.get("createdAt", "unknown")
|
||||||
|
modified = result.get("modifiedAt", "unknown")
|
||||||
|
extension = result.get("extension", "unknown")
|
||||||
|
name = result.get("name", path)
|
||||||
|
return (
|
||||||
|
f"File: {name}\n"
|
||||||
|
f" Extension: {extension}\n"
|
||||||
|
f" Size: {size} bytes\n"
|
||||||
|
f" Created: {created}\n"
|
||||||
|
f" Modified: {modified}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FILESYSTEM_TOOLS: list[Any] = [
|
||||||
|
list_directory,
|
||||||
|
read_file_content,
|
||||||
|
get_file_metadata,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def make_directory_tools(base_directory: str) -> list[Any]:
|
||||||
|
"""Return filesystem tools that resolve relative paths against *base_directory*.
|
||||||
|
|
||||||
|
Use this instead of ``FILESYSTEM_TOOLS`` whenever you know the user's target
|
||||||
|
directory upfront (e.g., journey setup sessions). Relative paths like ``"."``
|
||||||
|
from the LLM are resolved to the correct absolute path before being sent to
|
||||||
|
the Electron client, preventing it from falling back to its own CWD.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _compact_for_journey(raw: str) -> str:
|
||||||
|
"""Strip HTML noise and truncate for journey exploration.
|
||||||
|
|
||||||
|
The journey LLM only needs to understand file structure (headers,
|
||||||
|
first paragraphs). Full CSS/style blocks are pure noise that eat
|
||||||
|
up context window budget.
|
||||||
|
"""
|
||||||
|
text = re.sub(r"<style[^>]*>.*?</style>", "", raw, flags=re.DOTALL | re.IGNORECASE)
|
||||||
|
text = re.sub(r"<script[^>]*>.*?</script>", "", text, flags=re.DOTALL | re.IGNORECASE)
|
||||||
|
text = re.sub(r"<!--.*?-->", "", text, flags=re.DOTALL)
|
||||||
|
if len(text) > _JOURNEY_READ_MAX_CHARS:
|
||||||
|
text = text[:_JOURNEY_READ_MAX_CHARS] + "\n[…truncated for exploration]"
|
||||||
|
return text
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_directory(path: str) -> str: # noqa: F811
|
||||||
|
"""List files and folders in a local directory on the user's device.
|
||||||
|
|
||||||
|
Returns a formatted listing of entries with name, type (file/directory),
|
||||||
|
and full path.
|
||||||
|
"""
|
||||||
|
resolved = _resolve_path(path, base_directory)
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": resolved},
|
||||||
|
)
|
||||||
|
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
return f"Directory '{resolved}' is empty or does not exist."
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
entry_type = entry.get("type", "unknown")
|
||||||
|
entry_name = entry.get("name", "")
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||||
|
return f"Directory listing for '{resolved}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_file_content(path: str) -> str: # noqa: F811
|
||||||
|
"""Read the text content of a local file on the user's device.
|
||||||
|
|
||||||
|
Returns the file content as a string. Large files may be truncated
|
||||||
|
by the Electron client.
|
||||||
|
"""
|
||||||
|
resolved = _resolve_path(path, base_directory)
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": resolved},
|
||||||
|
)
|
||||||
|
content: str = result.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return f"File '{resolved}' is empty or could not be read."
|
||||||
|
return _compact_for_journey(content)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_file_metadata(path: str) -> str: # noqa: F811
|
||||||
|
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||||
|
|
||||||
|
Returns a formatted summary of the file's metadata.
|
||||||
|
"""
|
||||||
|
resolved = _resolve_path(path, base_directory)
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_file_metadata",
|
||||||
|
data={"path": resolved},
|
||||||
|
)
|
||||||
|
size = result.get("size", "unknown")
|
||||||
|
created = result.get("createdAt", "unknown")
|
||||||
|
modified = result.get("modifiedAt", "unknown")
|
||||||
|
extension = result.get("extension", "unknown")
|
||||||
|
name = result.get("name", resolved)
|
||||||
|
return (
|
||||||
|
f"File: {name}\n"
|
||||||
|
f" Extension: {extension}\n"
|
||||||
|
f" Size: {size} bytes\n"
|
||||||
|
f" Created: {created}\n"
|
||||||
|
f" Modified: {modified}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [list_directory, read_file_content, get_file_metadata]
|
||||||
@@ -2,38 +2,31 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
from app.core.llm import embed
|
||||||
from app.core.llm import embed, get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
_UUID_RE = re.compile(
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - content is always Markdown; preserve formatting when updating\n"
|
|
||||||
" - project_id is optional; link a note to a project when mentioned\n"
|
|
||||||
" - When updating, call get_note first if you need to read existing content\n"
|
|
||||||
" before appending or replacing sections\n"
|
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
|
||||||
" when the user is working within a specific project\n"
|
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
|
||||||
" is already in the note (retrieved via get_note)."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="notes",
|
table="notes",
|
||||||
filters={"projectId": project_id or None},
|
filters={"projectId": normalized_project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -122,23 +115,15 @@ async def delete_note(note_id: str) -> str:
|
|||||||
return f"Note {note_id} deleted."
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
NOTE_TOOLS: list[Any] = [
|
||||||
class NoteAgent(ChatAgent):
|
list_notes,
|
||||||
def get_name(self) -> str:
|
get_note,
|
||||||
return "note_agent"
|
create_note,
|
||||||
|
update_note,
|
||||||
|
delete_note,
|
||||||
|
]
|
||||||
|
|
||||||
def get_description(self) -> str:
|
NOTE_READ_TOOLS: list[Any] = [
|
||||||
return "Manages notes: list, get, create, update, delete"
|
list_notes,
|
||||||
|
get_note,
|
||||||
def get_tools(self) -> list[Any]:
|
]
|
||||||
return [list_notes, get_note, create_note, update_note, delete_note]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -2,32 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
|
||||||
"You are a project management assistant. You help users create, find,\n"
|
|
||||||
"update, and archive projects in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: active, archived\n"
|
|
||||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
|
||||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
|
||||||
" derive it from context data — do not fabricate content\n"
|
|
||||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
|
||||||
" user wants a complete cross-client view including archived projects\n"
|
|
||||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
|
||||||
" list_projects if you only have a project name\n"
|
|
||||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
|
||||||
" only call delete_project when the user explicitly confirms deletion."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_projects(
|
async def list_projects(
|
||||||
@@ -137,30 +117,17 @@ async def delete_project(project_id: str) -> str:
|
|||||||
return f"Project {project_id} permanently deleted."
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
PROJECT_TOOLS: list[Any] = [
|
||||||
class ProjectAgent(ChatAgent):
|
list_projects,
|
||||||
def get_name(self) -> str:
|
list_all_projects,
|
||||||
return "project_agent"
|
get_project,
|
||||||
|
create_project,
|
||||||
|
update_project,
|
||||||
|
delete_project,
|
||||||
|
]
|
||||||
|
|
||||||
def get_description(self) -> str:
|
PROJECT_READ_TOOLS: list[Any] = [
|
||||||
return "Manages projects: list, get, create, update, archive, delete"
|
list_projects,
|
||||||
|
list_all_projects,
|
||||||
def get_tools(self) -> list[Any]:
|
get_project,
|
||||||
return [
|
]
|
||||||
list_projects,
|
|
||||||
list_all_projects,
|
|
||||||
get_project,
|
|
||||||
create_project,
|
|
||||||
update_project,
|
|
||||||
delete_project,
|
|
||||||
]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -2,35 +2,23 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
_UUID_RE = re.compile(
|
||||||
"You are a task management assistant for a project workspace.\n"
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
"You create, update, list, and track tasks and their comments.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: todo, in_progress, done\n"
|
|
||||||
" - priority must be one of: high, medium, low\n"
|
|
||||||
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
|
||||||
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
|
||||||
" - project_id is optional; link to a project when the user mentions one\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
|
||||||
" did not explicitly request; 0 otherwise\n"
|
|
||||||
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
|
|
||||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
|
||||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Always confirm the action in plain, user-friendly language."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
|
||||||
# ── Task tools ────────────────────────────────────────────────────────
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -43,11 +31,12 @@ async def list_tasks(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
filters={
|
filters={
|
||||||
"projectId": project_id or None,
|
"projectId": normalized_project_id or None,
|
||||||
"status": status or None,
|
"status": status or None,
|
||||||
"search": search or None,
|
"search": search or None,
|
||||||
"orderBy": order_by or None,
|
"orderBy": order_by or None,
|
||||||
@@ -73,7 +62,6 @@ async def create_task(
|
|||||||
due_date: int = 0,
|
due_date: int = 0,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
is_approved: int = 0,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a new task.
|
"""Create a new task.
|
||||||
title: task title (required)
|
title: task title (required)
|
||||||
@@ -84,7 +72,6 @@ async def create_task(
|
|||||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||||
project_id: optional UUID of the parent project
|
project_id: optional UUID of the parent project
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
is_approved: 0 until the user confirms; 1 when confirmed
|
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -98,7 +85,6 @@ async def create_task(
|
|||||||
"dueDate": due_date or None,
|
"dueDate": due_date or None,
|
||||||
"projectId": project_id or None,
|
"projectId": project_id or None,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
"isApproved": is_approved,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
@@ -118,12 +104,10 @@ async def update_task(
|
|||||||
assignees: str = "",
|
assignees: str = "",
|
||||||
due_date: int = -1,
|
due_date: int = -1,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
is_approved: int = -1,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update fields on an existing task. Only pass fields you want to change.
|
"""Update fields on an existing task. Only pass fields you want to change.
|
||||||
task_id: the task's UUID (required)
|
task_id: the task's UUID (required)
|
||||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||||
is_approved: -1 means unchanged; 0 or 1 sets the value
|
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
@@ -140,8 +124,6 @@ async def update_task(
|
|||||||
updates["dueDate"] = due_date or None
|
updates["dueDate"] = due_date or None
|
||||||
if project_id:
|
if project_id:
|
||||||
updates["projectId"] = project_id
|
updates["projectId"] = project_id
|
||||||
if is_approved != -1:
|
|
||||||
updates["isApproved"] = is_approved
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
@@ -209,8 +191,12 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|||||||
table="taskComments",
|
table="taskComments",
|
||||||
data={"taskId": task_id, "author": author, "content": content},
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result.get("row", {})
|
||||||
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
|
row_author = row.get("author", author)
|
||||||
|
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
|
||||||
|
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||||
|
row_comment_id = row.get("id", "unknown")
|
||||||
|
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -223,32 +209,19 @@ async def delete_task_comment(comment_id: str) -> str:
|
|||||||
# ── Agent ─────────────────────────────────────────────────────────────
|
# ── Agent ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
TASK_TOOLS: list[Any] = [
|
||||||
class TaskAgent(ChatAgent):
|
list_tasks,
|
||||||
def get_name(self) -> str:
|
create_task,
|
||||||
return "task_agent"
|
update_task,
|
||||||
|
delete_task,
|
||||||
|
list_tasks_due_today,
|
||||||
|
list_task_comments,
|
||||||
|
add_task_comment,
|
||||||
|
delete_task_comment,
|
||||||
|
]
|
||||||
|
|
||||||
def get_description(self) -> str:
|
TASK_READ_TOOLS: list[Any] = [
|
||||||
return "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
list_tasks,
|
||||||
|
list_tasks_due_today,
|
||||||
def get_tools(self) -> list[Any]:
|
list_task_comments,
|
||||||
return [
|
]
|
||||||
list_tasks,
|
|
||||||
create_task,
|
|
||||||
update_task,
|
|
||||||
delete_task,
|
|
||||||
list_tasks_due_today,
|
|
||||||
list_task_comments,
|
|
||||||
add_task_comment,
|
|
||||||
delete_task_comment,
|
|
||||||
]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = get_llm()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -2,37 +2,31 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
_UUID_RE = re.compile(
|
||||||
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
|
||||||
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
|
||||||
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Listing without a project_id returns all timelines across projects\n"
|
|
||||||
" - Always echo the title and formatted date in your confirmation."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
"""List timelines. Provide project_id to scope to a specific project."""
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
filters={"projectId": project_id or None},
|
filters={"projectId": normalized_project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -47,14 +41,12 @@ async def create_timeline(
|
|||||||
title: str,
|
title: str,
|
||||||
date: int,
|
date: int,
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
is_approved: int = 0,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a project timeline (milestone).
|
"""Create a project timeline (milestone).
|
||||||
project_id: REQUIRED UUID of the parent project
|
project_id: REQUIRED UUID of the parent project
|
||||||
title: descriptive name for the milestone
|
title: descriptive name for the milestone
|
||||||
date: Unix timestamp in milliseconds
|
date: Unix timestamp in milliseconds
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
is_approved: 0 until the user confirms
|
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -64,7 +56,6 @@ async def create_timeline(
|
|||||||
"title": title,
|
"title": title,
|
||||||
"date": date,
|
"date": date,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
"isApproved": is_approved,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
@@ -76,20 +67,16 @@ async def update_timeline(
|
|||||||
timeline_id: str,
|
timeline_id: str,
|
||||||
title: str = "",
|
title: str = "",
|
||||||
date: int = -1,
|
date: int = -1,
|
||||||
is_approved: int = -1,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update a timeline. Only pass fields that should change.
|
"""Update a timeline. Only pass fields that should change.
|
||||||
timeline_id: UUID of the timeline (required)
|
timeline_id: UUID of the timeline (required)
|
||||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||||
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
updates["title"] = title
|
updates["title"] = title
|
||||||
if date != -1:
|
if date != -1:
|
||||||
updates["date"] = date
|
updates["date"] = date
|
||||||
if is_approved != -1:
|
|
||||||
updates["isApproved"] = is_approved
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
@@ -106,23 +93,33 @@ async def delete_timeline(timeline_id: str) -> str:
|
|||||||
return f"Timeline {timeline_id} deleted."
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
@tool
|
||||||
class TimelineAgent(ChatAgent):
|
async def list_timelines_today() -> str:
|
||||||
def get_name(self) -> str:
|
"""List all timeline events (milestones) whose date falls on today (UTC)."""
|
||||||
return "timeline_agent"
|
now = datetime.now(tz=timezone.utc)
|
||||||
|
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
|
||||||
|
end_ms = start_ms + 86_400_000 - 1
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="timelines",
|
||||||
|
filters={"dateFrom": start_ms, "dateTo": end_ms},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No timeline events today."
|
||||||
|
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"Timeline events today ({len(rows)}):\n" + "\n".join(lines)
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages project timelines (milestones): list, create, update, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
TIMELINE_TOOLS: list[Any] = [
|
||||||
return [list_timelines, create_timeline, update_timeline, delete_timeline]
|
list_timelines,
|
||||||
|
list_timelines_today,
|
||||||
|
create_timeline,
|
||||||
|
update_timeline,
|
||||||
|
delete_timeline,
|
||||||
|
]
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
TIMELINE_READ_TOOLS: list[Any] = [
|
||||||
llm = get_llm()
|
list_timelines,
|
||||||
messages = [
|
list_timelines_today,
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
]
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -55,23 +55,49 @@ async def get_current_user(
|
|||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
|
|
||||||
# Live tier lookup — subscription row is the authoritative source.
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
|
# In dev, fall back to 'power' (unlimited) so quota limits don't
|
||||||
|
# block local development when no Stripe subscription exists.
|
||||||
from app.models import Subscription, User # noqa: PLC0415
|
from app.models import Subscription, User # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str = result.scalar_one_or_none() or "free"
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
# Fetch name/surname from user row.
|
# Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row.
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
select(User.name, User.surname).where(User.id == user_id)
|
select(
|
||||||
|
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||||
|
User.password_hash,
|
||||||
|
).where(User.id == user_id)
|
||||||
)
|
)
|
||||||
user_row = user_result.one_or_none()
|
user_row = user_result.one_or_none()
|
||||||
|
|
||||||
|
# Convert onboarding_completed_at to epoch ms (int) or None.
|
||||||
|
onboarding_ms: int | None = None
|
||||||
|
if user_row and user_row.onboarding_completed_at is not None:
|
||||||
|
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||||
|
|
||||||
|
# Load decrypted core memory.
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||||
|
|
||||||
|
memory_dict: dict[str, str] = {}
|
||||||
|
try:
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
blocks = await mw.list_core_blocks(user_id)
|
||||||
|
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||||
|
except Exception:
|
||||||
|
pass # Non-critical — return empty memory on failure
|
||||||
|
|
||||||
return UserProfile(
|
return UserProfile(
|
||||||
id=user_id,
|
id=user_id,
|
||||||
email=email,
|
email=email,
|
||||||
name=user_row.name if user_row else None,
|
name=user_row.name if user_row else None,
|
||||||
surname=user_row.surname if user_row else None,
|
surname=user_row.surname if user_row else None,
|
||||||
|
avatar_url=user_row.avatar_url if user_row else None,
|
||||||
|
has_password=bool(user_row.password_hash) if user_row else False,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
|
onboarding_completed_at=onboarding_ms,
|
||||||
|
memory=memory_dict,
|
||||||
) # type: ignore[arg-type]
|
) # type: ignore[arg-type]
|
||||||
|
|||||||
@@ -8,8 +8,7 @@ that could reveal server-side prompt IP:
|
|||||||
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
||||||
- Exact-match known prompt fingerprints
|
- Exact-match known prompt fingerprints
|
||||||
|
|
||||||
Binary responses (storage blobs, backup data) are never touched — the
|
The middleware only activates for paths under /api/v1/chat.
|
||||||
middleware only activates for paths under /api/v1/chat.
|
|
||||||
|
|
||||||
Any sanitisation event is logged as a WARNING with the request path and the
|
Any sanitisation event is logged as a WARNING with the request path and the
|
||||||
names of the fields that were modified.
|
names of the fields that were modified.
|
||||||
|
|||||||
@@ -1,74 +1,71 @@
|
|||||||
"""Chatbot Journey endpoints — guided conversation to build an agent prompt_template.
|
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig.
|
||||||
|
|
||||||
Endpoints:
|
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
||||||
POST /agents/journey/start — start a new journey session
|
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
||||||
POST /agents/journey/message — continue the conversation
|
frames to the functions exported here.
|
||||||
|
|
||||||
Sessions are stored in-memory with a 30-minute TTL. Stale entries are
|
|
||||||
cleaned up lazily on access. Upgrade to Redis for multi-instance deployments.
|
|
||||||
|
|
||||||
Journey flow:
|
Journey flow:
|
||||||
1. Client sends ``{ agent_type, agent_id? }`` to ``/start``.
|
1. FE sends ``journey_start`` frame with basic agent info (directory,
|
||||||
2. Server creates a session, calls the LLM with a contextual system prompt,
|
data_types, schedule).
|
||||||
and returns the first question.
|
2. Server creates an in-memory session, sets up a WS executor so the
|
||||||
3. Client sends follow-up messages to ``/message``.
|
setup LLM can use file-system tools, does a first directory scrape,
|
||||||
4. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` block
|
and sends back a ``journey_reply`` with the first question.
|
||||||
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
3. FE sends ``journey_message`` frames for each user reply.
|
||||||
5. Server parses the block, sets ``done=True``, and returns the template.
|
4. Server appends the user message, calls the LLM (which may read files
|
||||||
|
via tools), and sends back a ``journey_reply``.
|
||||||
The ``prompt_template`` from the final response is meant to be stored in
|
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON
|
||||||
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
|
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
|
||||||
by the Electron client (via the agent CRUD endpoints).
|
6. Server parses and validates the JSON with Pydantic, sends
|
||||||
|
``journey_reply`` with ``done=True`` and the serialised config.
|
||||||
|
FE stores it locally.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.agents.filesystem_agent import make_directory_tools
|
||||||
from app.core.llm import get_llm
|
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||||
from app.db import get_session
|
from app.core.llm import get_agent_llm, model_for_agent
|
||||||
from app.models import CloudAgentConfig, LocalAgentConfig
|
from app.schemas import AgentConfig
|
||||||
from app.schemas import (
|
|
||||||
JourneyMessageRequest,
|
|
||||||
JourneyResponse,
|
|
||||||
JourneyStartRequest,
|
|
||||||
UserProfile,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/agents/journey", tags=["agents"])
|
|
||||||
|
|
||||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON.
|
||||||
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
_CONFIG_START = "AGENT_CONFIG_START"
|
||||||
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
_CONFIG_END = "AGENT_CONFIG_END"
|
||||||
|
|
||||||
# Maximum number of conversation turns before the LLM is nudged to wrap up.
|
# Minimum turns before we consider nudging the LLM to wrap up.
|
||||||
_MAX_TURNS: int = 5
|
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||||
|
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
|
||||||
|
_MAX_TURNS: int = 15
|
||||||
|
# Max tool-calling steps per LLM invocation.
|
||||||
|
_MAX_TOOL_STEPS: int = 6
|
||||||
|
|
||||||
# ── In-memory session store ───────────────────────────────────────────────
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _JourneySession:
|
class JourneySession:
|
||||||
session_id: str
|
session_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
agent_type: str # "local" | "cloud"
|
agent_type: str # "local" | "cloud"
|
||||||
|
directory: str
|
||||||
|
data_types: list[str]
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
system_prompt: str = ""
|
||||||
|
langfuse_prompt: Any = None
|
||||||
created_at: float = field(default_factory=time.monotonic)
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
def is_expired(self) -> bool:
|
def is_expired(self) -> bool:
|
||||||
@@ -76,103 +73,182 @@ class _JourneySession:
|
|||||||
|
|
||||||
|
|
||||||
# session_id → session
|
# session_id → session
|
||||||
_sessions: dict[str, _JourneySession] = {}
|
_sessions: dict[str, JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
def _get_session(session_id: str, user_id: str) -> _JourneySession:
|
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||||
"""Retrieve session; raise 404 on missing, expired, or wrong owner."""
|
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||||
s = _sessions.get(session_id)
|
s = _sessions.get(session_id)
|
||||||
if s is None or s.is_expired():
|
if s is None or s.is_expired():
|
||||||
_sessions.pop(session_id, None)
|
_sessions.pop(session_id, None)
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
return None
|
||||||
if s.user_id != user_id:
|
if s.user_id != user_id:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
return None
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# ── System prompt builder ─────────────────────────────────────────────────
|
# ── System prompt ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_LOCAL_PREAMBLE = """\
|
_JOURNEY_SYSTEM_PROMPT = """\
|
||||||
What kind of files are in the directories you want to monitor? \
|
|
||||||
(for example: emails saved as .eml, documents in .pdf or .txt, markdown notes, etc.)"""
|
|
||||||
|
|
||||||
_CLOUD_PREAMBLE = """\
|
|
||||||
What kind of emails or messages should I look for? \
|
|
||||||
(for example: client communications, invoices, meeting notes, project updates, etc.)"""
|
|
||||||
|
|
||||||
_SYSTEM_PROMPT_TEMPLATE = """\
|
|
||||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
Your job is to understand exactly what data the user wants to extract from their {source_description} \
|
Your job is to understand what files the user has in their directory and produce a
|
||||||
and produce a detailed prompt_template that a separate AI will use as its instruction set.
|
structured AgentConfig JSON that the extraction agent will use as its instruction set.
|
||||||
|
|
||||||
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
|
You have access to file-system tools to explore the user's directory:
|
||||||
1. The type and format of the source content.
|
- list_directory: see folder structure and file names
|
||||||
2. Which data types to extract: tasks, notes, timelines, and/or projects.
|
- read_file_content: peek at a file's content
|
||||||
3. How fields should be mapped (e.g. email subject → task title).
|
- get_file_metadata: check file size, extension, dates
|
||||||
4. Priority or status rules (e.g. "urgent" keyword → high priority).
|
|
||||||
5. Any special handling, date extraction, or exclusions.
|
|
||||||
|
|
||||||
After 3-5 questions (when you have enough information), output the final prompt_template between \
|
The user's configured directory is: {directory}
|
||||||
these exact markers on their own lines:
|
Target data types: {data_types}
|
||||||
|
|
||||||
{template_start}
|
## Your process
|
||||||
<the complete extraction prompt here>
|
|
||||||
{template_end}
|
|
||||||
|
|
||||||
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
|
### Step 1 — Explore the directory
|
||||||
and must return a JSON array of records in this shape:
|
Use list_directory and read_file_content to understand what types of files are present
|
||||||
[{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
|
(HTML emails, plain-text documents, CSVs, etc.).
|
||||||
|
|
||||||
|
### Step 2 — Identify content types
|
||||||
|
For each distinct file type found, decide:
|
||||||
|
- A short id (e.g. "email_html", "plain_text", "csv")
|
||||||
|
- Which preprocessing handler to use: "email_html" for HTML emails, "generic" for everything else
|
||||||
|
- A human-readable label and optional detection_hint
|
||||||
|
|
||||||
|
### Step 3 — Ask focused questions (one at a time)
|
||||||
|
Cover these topics based on what you discovered:
|
||||||
|
1. How to map content to entity types (task / note / timeline entry)
|
||||||
|
2. Field mapping rules (e.g. email Subject → task title, filename → note title)
|
||||||
|
3. Priority or status rules (e.g. "urgent" in subject → high priority)
|
||||||
|
4. Date extraction (e.g. "by Friday" → dueDate)
|
||||||
|
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
|
||||||
|
|
||||||
|
### Step 4 — Produce the AgentConfig JSON
|
||||||
|
Once you are ≥ 90% confident, output the final config between these exact markers
|
||||||
|
(each on its own line):
|
||||||
|
|
||||||
|
{config_start}
|
||||||
|
{{
|
||||||
|
"content_types": [
|
||||||
|
{{
|
||||||
|
"id": "email_html",
|
||||||
|
"label": "Email HTML",
|
||||||
|
"detection_hint": "HTML file with From/To/Subject headers",
|
||||||
|
"preprocessing": "email_html",
|
||||||
|
"extraction_prompt": "Detailed extraction instructions for this content type..."
|
||||||
|
}}
|
||||||
|
],
|
||||||
|
"global_rules": [
|
||||||
|
"If the file cannot be matched to any project, do not create any entity."
|
||||||
|
],
|
||||||
|
"data_types": {data_types_json}
|
||||||
|
}}
|
||||||
|
{config_end}
|
||||||
|
|
||||||
|
## Rules for the extraction_prompt field
|
||||||
|
- Describe when to create a task vs note vs timeline entry (be specific and concrete)
|
||||||
|
- Include field mapping rules based on what you found in the directory
|
||||||
|
- Include priority/status/date rules if applicable
|
||||||
|
- Do NOT include projectId logic — the runner handles project assignment automatically
|
||||||
|
- Do NOT mention isAiSuggested — the runner always sets it to 1
|
||||||
|
|
||||||
|
## Constraints
|
||||||
|
- Never ask about projects, projectId, or how to link records to projects
|
||||||
|
- Never include projectId or project creation logic in the generated config
|
||||||
|
- Keep asking questions until ≥ 90% confident, then output the JSON immediately
|
||||||
|
|
||||||
Rules for the generated template:
|
|
||||||
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
|
|
||||||
- Include concrete examples of mappings.
|
|
||||||
- Mention that Electron adds id/createdAt/updatedAt automatically.
|
|
||||||
- Set isAiSuggested: true and isApproved: false on every record.
|
|
||||||
{existing_section}\
|
{existing_section}\
|
||||||
Do not ask more than {max_turns} questions total. Start with your first question now.\
|
Begin by exploring the directory, then ask your first question.\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
def _build_system_prompt(
|
||||||
source_description = (
|
directory: str,
|
||||||
"files in local directories" if agent_type == "local" else "emails and messages from cloud providers"
|
data_types: list[str],
|
||||||
)
|
existing_config: str | None = None,
|
||||||
|
) -> tuple[str, Any]:
|
||||||
|
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
|
||||||
existing_section = (
|
existing_section = (
|
||||||
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
"\nThe user already has the following AgentConfig — refine it based on their answers:\n"
|
||||||
f"---\n{existing_template}\n---\n"
|
f"```json\n{existing_config}\n```\n"
|
||||||
if existing_template
|
if existing_config
|
||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
return _SYSTEM_PROMPT_TEMPLATE.format(
|
template, prompt_obj = get_prompt_or_fallback(
|
||||||
source_description=source_description,
|
"journey_system", _JOURNEY_SYSTEM_PROMPT
|
||||||
template_start=_TEMPLATE_START,
|
|
||||||
template_end=_TEMPLATE_END,
|
|
||||||
existing_section=existing_section,
|
|
||||||
max_turns=_MAX_TURNS,
|
|
||||||
)
|
)
|
||||||
|
compiled = compile_prompt(
|
||||||
|
template,
|
||||||
|
prompt_obj,
|
||||||
|
directory=directory,
|
||||||
|
data_types=", ".join(data_types),
|
||||||
|
data_types_json=json.dumps(data_types),
|
||||||
|
config_start=_CONFIG_START,
|
||||||
|
config_end=_CONFIG_END,
|
||||||
|
existing_section=existing_section,
|
||||||
|
)
|
||||||
|
return compiled, prompt_obj
|
||||||
|
|
||||||
|
|
||||||
def _first_question(agent_type: str) -> str:
|
# ── AgentConfig extraction ────────────────────────────────────────────────
|
||||||
return _LOCAL_PREAMBLE if agent_type == "local" else _CLOUD_PREAMBLE
|
|
||||||
|
|
||||||
|
|
||||||
# ── Template extraction ───────────────────────────────────────────────────
|
def _extract_agent_config(text: str) -> str | None:
|
||||||
|
"""Return validated AgentConfig JSON string from between markers, or None.
|
||||||
|
|
||||||
|
Parses the JSON with Pydantic to ensure it conforms to the schema before
|
||||||
def _extract_template(text: str) -> str | None:
|
returning. Returns None if markers are absent or JSON is invalid.
|
||||||
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
"""
|
||||||
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
if _CONFIG_START not in text or _CONFIG_END not in text:
|
||||||
|
return None
|
||||||
|
start_idx = text.index(_CONFIG_START) + len(_CONFIG_START)
|
||||||
|
end_idx = text.index(_CONFIG_END)
|
||||||
|
raw = text[start_idx:end_idx].strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = AgentConfig.model_validate_json(raw)
|
||||||
|
return parsed.model_dump_json()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc)
|
||||||
return None
|
return None
|
||||||
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
|
||||||
end_idx = text.index(_TEMPLATE_END)
|
|
||||||
return text[start_idx:end_idx].strip() or None
|
|
||||||
|
|
||||||
|
|
||||||
# ── LLM call ─────────────────────────────────────────────────────────────
|
# ── LLM call with tool support ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
def _as_text(content: Any) -> str:
|
||||||
"""Build LangChain messages from history and invoke the LLM."""
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_llm_with_tools(
|
||||||
|
system_prompt: str,
|
||||||
|
history: list[dict[str, Any]],
|
||||||
|
tools: list[Any],
|
||||||
|
*,
|
||||||
|
user_id: str = "",
|
||||||
|
session_id: str = "",
|
||||||
|
langfuse_prompt: Any = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||||
|
|
||||||
|
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||||
|
continue until a final text response is produced.
|
||||||
|
"""
|
||||||
|
lf = get_langfuse()
|
||||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
for turn in history:
|
for turn in history:
|
||||||
if turn["role"] == "user":
|
if turn["role"] == "user":
|
||||||
@@ -180,138 +256,258 @@ async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
|||||||
else:
|
else:
|
||||||
messages.append(AIMessage(content=turn["content"]))
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
llm = get_llm(model=None, temperature=0.4)
|
llm = get_agent_llm("setup", temperature=0.4)
|
||||||
response = await llm.ainvoke(messages)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
return response.content # type: ignore[return-value]
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
|
||||||
|
_lf_ctx.__enter__()
|
||||||
|
|
||||||
|
_span_ctx = (
|
||||||
|
lf.start_as_current_observation(
|
||||||
|
as_type="span",
|
||||||
|
name="journey-setup",
|
||||||
|
input=history[-1]["content"] if history else "",
|
||||||
|
)
|
||||||
|
if lf else None
|
||||||
|
)
|
||||||
|
_span = _span_ctx.__enter__() if _span_ctx else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
for step in range(_MAX_TOOL_STEPS):
|
||||||
|
_gen_ctx = (
|
||||||
|
lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="journey-setup-llm",
|
||||||
|
model=model_for_agent("setup"),
|
||||||
|
prompt=langfuse_prompt,
|
||||||
|
input=messages,
|
||||||
|
)
|
||||||
|
if lf else None
|
||||||
|
)
|
||||||
|
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
if _gen_ctx:
|
||||||
|
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||||
|
_gen_ctx.__exit__(None, None, None)
|
||||||
|
|
||||||
|
resp_text = _as_text(response.content)
|
||||||
|
|
||||||
|
# Guard against empty responses (e.g. model returned finish_reason
|
||||||
|
# 'error' which LiteLLM maps to 'stop' with empty content).
|
||||||
|
if not response.tool_calls and not resp_text.strip():
|
||||||
|
logger.warning(
|
||||||
|
"agent_setup: journey LLM returned empty response at step %d — retrying",
|
||||||
|
step,
|
||||||
|
)
|
||||||
|
# Drop the empty AIMessage so we don't pollute history, and retry.
|
||||||
|
continue
|
||||||
|
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
if _span:
|
||||||
|
_span.update(output=resp_text)
|
||||||
|
return resp_text
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"agent_setup: journey tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_setup: journey tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:800],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
# Fallback: exceeded max steps.
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
final_text = _as_text(final.content)
|
||||||
|
if _span:
|
||||||
|
_span.update(output=final_text)
|
||||||
|
return final_text or (
|
||||||
|
"Sorry, I had trouble processing the files. "
|
||||||
|
"Could you try again? If the issue persists, the files might be too large for me to analyse."
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if _span_ctx:
|
||||||
|
_span_ctx.__exit__(None, None, None)
|
||||||
|
_lf_ctx.__exit__(None, None, None)
|
||||||
|
if lf:
|
||||||
|
lf.flush()
|
||||||
|
|
||||||
|
|
||||||
# ── Existing-config loader ────────────────────────────────────────────────
|
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _load_existing_template(
|
async def handle_journey_start(
|
||||||
agent_id: str,
|
|
||||||
user_id: str,
|
user_id: str,
|
||||||
db: AsyncSession,
|
frame: dict[str, Any],
|
||||||
) -> str | None:
|
) -> dict[str, Any]:
|
||||||
"""Return the prompt_template of an existing agent config, or None."""
|
"""Handle a ``journey_start`` WS frame.
|
||||||
# Try local first, then cloud.
|
|
||||||
local_result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.id == agent_id,
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
local = local_result.scalar_one_or_none()
|
|
||||||
if local is not None:
|
|
||||||
return local.prompt_template
|
|
||||||
|
|
||||||
cloud_result = await db.execute(
|
Creates a session, runs the setup LLM with directory exploration,
|
||||||
select(CloudAgentConfig).where(
|
and returns the ``journey_reply`` payload.
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cloud = cloud_result.scalar_one_or_none()
|
|
||||||
return cloud.prompt_template if cloud is not None else None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/start", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
|
||||||
async def start_journey(
|
|
||||||
body: JourneyStartRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> JourneyResponse:
|
|
||||||
"""Start a new Chatbot Journey session.
|
|
||||||
|
|
||||||
If ``agent_id`` is provided the session is pre-seeded with the existing
|
|
||||||
agent's ``prompt_template`` so the user can refine it.
|
|
||||||
"""
|
"""
|
||||||
# Load existing template (may be None).
|
agent_type = frame.get("agent_type", "local")
|
||||||
existing_template: str | None = None
|
directory = frame.get("directory", "")
|
||||||
if body.agent_id:
|
data_types = frame.get("data_types", [])
|
||||||
existing_template = await _load_existing_template(body.agent_id, current_user.id, db)
|
existing_config = frame.get("existing_config")
|
||||||
# If agent_id was given but not found, proceed without seeding (don't 404 —
|
|
||||||
# the user may be starting a fresh journey for a not-yet-persisted config).
|
|
||||||
|
|
||||||
system_prompt = _build_system_prompt(body.agent_type, existing_template)
|
# Use the session_id provided by the FE so the reply matches the
|
||||||
first_question = _first_question(body.agent_type)
|
# listener key; fall back to a generated one if absent.
|
||||||
|
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||||
|
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_config)
|
||||||
|
|
||||||
session_id = str(uuid.uuid4())
|
session = JourneySession(
|
||||||
session = _JourneySession(
|
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=current_user.id,
|
user_id=user_id,
|
||||||
agent_type=body.agent_type,
|
agent_type=agent_type,
|
||||||
# Seed history with the AI's first question so it stays consistent.
|
directory=directory,
|
||||||
history=[{"role": "assistant", "content": first_question}],
|
data_types=data_types,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
)
|
)
|
||||||
# Store the system prompt inside the session for reuse in /message.
|
|
||||||
session.__dict__["_system_prompt"] = system_prompt # type: ignore[index]
|
# Seed with an initial user message — some providers require at least one
|
||||||
|
# user/input message to be present.
|
||||||
|
seed_history: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||||
|
]
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history=seed_history,
|
||||||
|
tools=make_directory_tools(directory),
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.extend(seed_history)
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
_sessions[session_id] = session
|
_sessions[session_id] = session
|
||||||
|
|
||||||
logger.info("Journey session %s started for user %s (agent_type=%s)", session_id, current_user.id, body.agent_type)
|
logger.info(
|
||||||
return JourneyResponse(session_id=session_id, message=first_question, done=False)
|
"agent_setup: journey session %s started for user %s (directory=%s)",
|
||||||
|
session_id,
|
||||||
|
user_id,
|
||||||
|
directory,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the LLM produced the config on the first turn (unlikely but possible).
|
||||||
|
agent_config = _extract_agent_config(ai_reply)
|
||||||
|
done = agent_config is not None
|
||||||
|
|
||||||
@router.post("/message", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
|
||||||
async def send_journey_message(
|
|
||||||
body: JourneyMessageRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> JourneyResponse:
|
|
||||||
"""Send a message in an existing Chatbot Journey session.
|
|
||||||
|
|
||||||
The server appends the user's message to the conversation history,
|
|
||||||
calls the LLM, and appends the AI reply. When the LLM wraps up with a
|
|
||||||
``prompt_template`` block the response includes ``done=True`` and the
|
|
||||||
extracted template.
|
|
||||||
"""
|
|
||||||
session = _get_session(body.session_id, current_user.id)
|
|
||||||
system_prompt: str = session.__dict__.get("_system_prompt", _build_system_prompt(session.agent_type, None)) # type: ignore[assignment]
|
|
||||||
|
|
||||||
# Append user turn to history.
|
|
||||||
session.history.append({"role": "user", "content": body.message})
|
|
||||||
|
|
||||||
# Call the LLM with the full conversation so far.
|
|
||||||
ai_reply = await _call_llm(system_prompt, session.history)
|
|
||||||
|
|
||||||
# Append AI turn.
|
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
|
||||||
|
|
||||||
# Check if the LLM produced the final template.
|
|
||||||
prompt_template = _extract_template(ai_reply)
|
|
||||||
done = prompt_template is not None
|
|
||||||
|
|
||||||
# Strip the sentinel markers from the message shown to the user.
|
|
||||||
display_message = ai_reply
|
display_message = ai_reply
|
||||||
if done:
|
if done:
|
||||||
display_message = (
|
display_message = (
|
||||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
||||||
or "Here is your agent configuration. You can save it or continue refining."
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
)
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
|
||||||
if done:
|
return {
|
||||||
logger.info("Journey session %s completed for user %s", body.session_id, current_user.id)
|
"type": "journey_reply",
|
||||||
# Clean up the session immediately on completion.
|
"session_id": session_id,
|
||||||
_sessions.pop(body.session_id, None)
|
"message": display_message,
|
||||||
else:
|
"done": done,
|
||||||
# Nudge the LLM to wrap up after max turns.
|
"agent_config": agent_config,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_message(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_message`` WS frame.
|
||||||
|
|
||||||
|
Appends the user message, calls the LLM, and returns the
|
||||||
|
``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
message = frame.get("message", "")
|
||||||
|
|
||||||
|
session = get_journey_session(session_id, user_id)
|
||||||
|
if session is None:
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": "Journey session not found or expired. Please start a new setup.",
|
||||||
|
"done": True,
|
||||||
|
"agent_config": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Append user turn.
|
||||||
|
session.history.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
# Call the LLM with tools.
|
||||||
|
session_tools = make_directory_tools(session.directory)
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=session_tools,
|
||||||
|
user_id=session.user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
langfuse_prompt=session.langfuse_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
|
# Check if the LLM produced the final config.
|
||||||
|
agent_config = _extract_agent_config(ai_reply)
|
||||||
|
done = agent_config is not None
|
||||||
|
|
||||||
|
# If the LLM didn't produce a config, nudge it once it hits the hard safety cap.
|
||||||
|
if not done:
|
||||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
if turns >= _MAX_TURNS:
|
if turns >= _MAX_TURNS:
|
||||||
# Add a system-level nudge as a hidden user message.
|
nudge_content = (
|
||||||
session.history.append({
|
"[System: You have enough information. Please generate the final "
|
||||||
"role": "user",
|
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
|
||||||
"content": (
|
)
|
||||||
"[System: You have enough information. Please generate the final "
|
session.history.append({"role": "user", "content": nudge_content})
|
||||||
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
|
||||||
),
|
|
||||||
})
|
|
||||||
|
|
||||||
return JourneyResponse(
|
nudge_reply = await _call_llm_with_tools(
|
||||||
session_id=body.session_id,
|
system_prompt=session.system_prompt,
|
||||||
message=display_message,
|
history=session.history,
|
||||||
done=done,
|
tools=session_tools,
|
||||||
prompt_template=prompt_template,
|
user_id=session.user_id,
|
||||||
)
|
session_id=session_id,
|
||||||
|
langfuse_prompt=session.langfuse_prompt,
|
||||||
|
)
|
||||||
|
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||||
|
|
||||||
|
agent_config = _extract_agent_config(nudge_reply)
|
||||||
|
if agent_config is not None:
|
||||||
|
done = True
|
||||||
|
ai_reply = nudge_reply
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
||||||
|
if _CONFIG_START in ai_reply
|
||||||
|
else "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"agent_config": agent_config,
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,48 +1,42 @@
|
|||||||
"""Agent CRUD routes: local directory agents and cloud connector agents.
|
"""Agent routes.
|
||||||
|
|
||||||
Endpoints:
|
Backend responsibilities are intentionally minimal:
|
||||||
GET /agents/catalog — hardcoded agent type catalog
|
GET /agents/catalog — static catalog for UI display
|
||||||
GET /agents/local — list user's local agent configs
|
POST /agents/can-create — billing eligibility check
|
||||||
POST /agents/local — create local agent (tier-gated)
|
POST /agents/trigger — trigger a local agent run
|
||||||
PUT /agents/local/{agent_id} — partial update (ownership check)
|
|
||||||
DELETE /agents/local/{agent_id} — delete + cascade run logs
|
Agent configuration is owned by the Electron app and is not persisted
|
||||||
GET /agents/cloud — list user's cloud agent configs
|
in backend agent-config tables.
|
||||||
POST /agents/cloud — create cloud agent (tier-gated)
|
|
||||||
PUT /agents/cloud/{agent_id} — partial update (ownership check)
|
|
||||||
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
|
|
||||||
GET /agents/runs — paginated run logs (agent_id, page, limit)
|
|
||||||
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
import logging
|
||||||
from typing import Any
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from pydantic import BaseModel
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy import func, or_, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.billing.tier_manager import FEATURES
|
from app.billing.tier_manager import FEATURES
|
||||||
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
from app.core.agent_runner import is_agent_running, run_local_agent
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
from app.models import AgentRunLog, LocalAgentConfig
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
AgentCatalogItem,
|
AgentCatalogItem,
|
||||||
|
AgentCreationCheckRequest,
|
||||||
|
AgentCreationCheckResponse,
|
||||||
AgentRunLogResponse,
|
AgentRunLogResponse,
|
||||||
CloudAgentConfigCreate,
|
AgentTriggerRequest,
|
||||||
CloudAgentConfigResponse,
|
|
||||||
CloudAgentConfigUpdate,
|
|
||||||
LocalAgentConfigCreate,
|
|
||||||
LocalAgentConfigResponse,
|
|
||||||
LocalAgentConfigUpdate,
|
|
||||||
UserProfile,
|
UserProfile,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/agents", tags=["agents"])
|
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||||
|
|
||||||
|
|
||||||
@@ -56,39 +50,21 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
|
|||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
# ── Model → schema converters ─────────────────────────────────────────
|
def _to_data_types(values: list[str]) -> list[str]:
|
||||||
|
normalize = {
|
||||||
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
|
"task": "tasks", "tasks": "tasks",
|
||||||
return LocalAgentConfigResponse(
|
"note": "notes", "notes": "notes",
|
||||||
id=a.id,
|
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||||
name=a.name,
|
"project": "projects", "projects": "projects",
|
||||||
device_id=a.device_id,
|
}
|
||||||
directory_paths=a.directory_paths,
|
seen: set[str] = set()
|
||||||
data_types=a.data_types,
|
result: list[str] = []
|
||||||
prompt_template=a.prompt_template,
|
for v in values:
|
||||||
file_extensions=a.file_extensions,
|
mapped = normalize.get(v)
|
||||||
schedule_cron=a.schedule_cron,
|
if mapped and mapped not in seen:
|
||||||
enabled=a.enabled,
|
seen.add(mapped)
|
||||||
last_run_at=_dt_ms_opt(a.last_run_at),
|
result.append(mapped)
|
||||||
created_at=_dt_ms(a.created_at),
|
return result
|
||||||
updated_at=_dt_ms(a.updated_at),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
|
|
||||||
return CloudAgentConfigResponse(
|
|
||||||
id=a.id,
|
|
||||||
provider=a.provider, # type: ignore[arg-type]
|
|
||||||
name=a.name,
|
|
||||||
data_types=a.data_types,
|
|
||||||
prompt_template=a.prompt_template,
|
|
||||||
schedule_cron=a.schedule_cron,
|
|
||||||
filter_config=a.filter_config,
|
|
||||||
enabled=a.enabled,
|
|
||||||
last_run_at=_dt_ms_opt(a.last_run_at),
|
|
||||||
created_at=_dt_ms(a.created_at),
|
|
||||||
updated_at=_dt_ms(a.updated_at),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||||
@@ -105,77 +81,42 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Ownership-checked lookups ─────────────────────────────────────────
|
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||||
|
|
||||||
async def _get_local_agent_for_user(
|
|
||||||
agent_id: str, user_id: str, db: AsyncSession
|
|
||||||
) -> LocalAgentConfig:
|
|
||||||
result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.id == agent_id,
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_cloud_agent_for_user(
|
|
||||||
agent_id: str, user_id: str, db: AsyncSession
|
|
||||||
) -> CloudAgentConfig:
|
|
||||||
result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(
|
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tier limit helper ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
|
|
||||||
"""Return combined enabled local + cloud agent count for the user."""
|
|
||||||
local_count = (
|
|
||||||
await db.execute(
|
|
||||||
select(func.count(LocalAgentConfig.id)).where(
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
LocalAgentConfig.enabled == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalar_one()
|
|
||||||
cloud_count = (
|
|
||||||
await db.execute(
|
|
||||||
select(func.count(CloudAgentConfig.id)).where(
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
CloudAgentConfig.enabled == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalar_one()
|
|
||||||
return local_count + cloud_count
|
|
||||||
|
|
||||||
|
|
||||||
def _enforce_agent_limit(tier: str, current_count: int) -> None:
|
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
if limit != -1 and current_count >= limit:
|
if limit != -1 and current_count >= limit:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
)
|
)
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
# ── Local page schema (used by runs endpoint) ─────────────────────────
|
async def _enforce_run_frequency(
|
||||||
|
tier: str,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||||
|
if limit == -1:
|
||||||
|
return # unlimited
|
||||||
|
|
||||||
class _RunsPage(BaseModel):
|
today_start = datetime.now(timezone.utc).replace(
|
||||||
total: int
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
page: int
|
)
|
||||||
limit: int
|
result = await db.execute(
|
||||||
items: list[AgentRunLogResponse]
|
select(func.count(AgentRunLog.id)).where(
|
||||||
|
AgentRunLog.user_id == user_id,
|
||||||
|
AgentRunLog.started_at >= today_start,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
runs_today: int = result.scalar_one()
|
||||||
|
|
||||||
|
if runs_today >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Catalog ───────────────────────────────────────────────────────────
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
@@ -209,229 +150,68 @@ async def get_agent_catalog(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# ── Local agent CRUD ──────────────────────────────────────────────────
|
@router.post("/can-create", response_model=AgentCreationCheckResponse)
|
||||||
|
async def can_create_agent(
|
||||||
@router.get("/local", response_model=list[LocalAgentConfigResponse])
|
body: AgentCreationCheckRequest,
|
||||||
async def list_local_agents(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
) -> AgentCreationCheckResponse:
|
||||||
) -> list[LocalAgentConfigResponse]:
|
"""Check if the user can create one more agent based on billing tier.
|
||||||
"""List all local directory agent configs owned by the authenticated user."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
return [_to_local_response(a) for a in result.scalars().all()]
|
|
||||||
|
|
||||||
|
Since configuration is client-owned, the Electron app sends its current
|
||||||
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
active agent count and the backend applies tier limits.
|
||||||
async def create_local_agent(
|
|
||||||
body: LocalAgentConfigCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> LocalAgentConfigResponse:
|
|
||||||
"""Create a new local directory agent config.
|
|
||||||
|
|
||||||
The combined count of enabled local and cloud agents for the user is
|
|
||||||
checked against the ``batch_active`` limit for their billing tier.
|
|
||||||
"""
|
"""
|
||||||
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
||||||
agent = LocalAgentConfig(
|
allowed = limit == -1 or body.active_agents < limit
|
||||||
user_id=current_user.id,
|
return AgentCreationCheckResponse(
|
||||||
name=body.name,
|
allowed=allowed,
|
||||||
device_id=body.device_id,
|
tier=current_user.tier,
|
||||||
directory_paths=body.directory_paths,
|
active_agents=body.active_agents,
|
||||||
data_types=body.data_types,
|
limit=limit,
|
||||||
prompt_template=body.prompt_template,
|
|
||||||
file_extensions=body.file_extensions,
|
|
||||||
schedule_cron=body.schedule_cron,
|
|
||||||
)
|
)
|
||||||
db.add(agent)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_local_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
|
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||||
async def update_local_agent(
|
|
||||||
agent_id: str,
|
|
||||||
body: LocalAgentConfigUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> LocalAgentConfigResponse:
|
|
||||||
"""Partially update a local agent config. Only provided fields are changed."""
|
|
||||||
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(agent, field, value)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_local_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/local/{agent_id}", response_model=dict)
|
|
||||||
async def delete_local_agent(
|
|
||||||
agent_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a local agent config. Associated run logs are cascade-deleted."""
|
|
||||||
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
await db.delete(agent)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud agent CRUD ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
|
|
||||||
async def list_cloud_agents(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[CloudAgentConfigResponse]:
|
|
||||||
"""List all cloud connector agent configs owned by the authenticated user."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
return [_to_cloud_response(a) for a in result.scalars().all()]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def create_cloud_agent(
|
|
||||||
body: CloudAgentConfigCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> CloudAgentConfigResponse:
|
|
||||||
"""Create a new cloud connector agent config.
|
|
||||||
|
|
||||||
The combined count of enabled local and cloud agents for the user is
|
|
||||||
checked against the ``batch_active`` limit for their billing tier.
|
|
||||||
"""
|
|
||||||
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
|
||||||
agent = CloudAgentConfig(
|
|
||||||
user_id=current_user.id,
|
|
||||||
provider=body.provider,
|
|
||||||
name=body.name,
|
|
||||||
data_types=body.data_types,
|
|
||||||
prompt_template=body.prompt_template,
|
|
||||||
oauth_token_encrypted=body.oauth_token_encrypted,
|
|
||||||
schedule_cron=body.schedule_cron,
|
|
||||||
filter_config=body.filter_config,
|
|
||||||
)
|
|
||||||
db.add(agent)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_cloud_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
|
|
||||||
async def update_cloud_agent(
|
|
||||||
agent_id: str,
|
|
||||||
body: CloudAgentConfigUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> CloudAgentConfigResponse:
|
|
||||||
"""Partially update a cloud agent config. Only provided fields are changed."""
|
|
||||||
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(agent, field, value)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_cloud_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/cloud/{agent_id}", response_model=dict)
|
|
||||||
async def delete_cloud_agent(
|
|
||||||
agent_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
|
|
||||||
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
await db.delete(agent)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Run logs ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/runs", response_model=_RunsPage)
|
|
||||||
async def list_run_logs(
|
|
||||||
agent_id: str | None = Query(default=None),
|
|
||||||
page: int = Query(default=1, ge=1),
|
|
||||||
limit: int = Query(default=20, ge=1, le=100),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> _RunsPage:
|
|
||||||
"""Return paginated run logs for the authenticated user.
|
|
||||||
|
|
||||||
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
|
|
||||||
"""
|
|
||||||
base_filter = [AgentRunLog.user_id == current_user.id]
|
|
||||||
if agent_id:
|
|
||||||
base_filter.append(AgentRunLog.agent_id == agent_id)
|
|
||||||
|
|
||||||
total = (
|
|
||||||
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
|
|
||||||
).scalar_one()
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(AgentRunLog)
|
|
||||||
.where(*base_filter)
|
|
||||||
.order_by(AgentRunLog.started_at.desc())
|
|
||||||
.offset((page - 1) * limit)
|
|
||||||
.limit(limit)
|
|
||||||
)
|
|
||||||
items = [_to_run_log_response(log) for log in result.scalars().all()]
|
|
||||||
|
|
||||||
return _RunsPage(total=total, page=page, limit=limit, items=items)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Manual trigger stub ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
|
||||||
async def trigger_agent_run(
|
async def trigger_agent_run(
|
||||||
agent_id: str,
|
body: AgentTriggerRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> AgentRunLogResponse:
|
) -> AgentRunLogResponse:
|
||||||
"""Manually trigger an agent run.
|
"""Trigger a local agent run using client-provided configuration."""
|
||||||
|
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||||
|
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
||||||
|
|
||||||
Looks up the agent config (local or cloud) by ID with ownership check,
|
last_run_dt = (
|
||||||
creates a run log entry with ``status="running"``, and returns it.
|
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
|
||||||
|
if body.last_run_at
|
||||||
Actual dispatch to the agent runner is wired in Step 3.4 once
|
else None
|
||||||
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
)
|
||||||
"""
|
config = LocalAgentConfig(
|
||||||
# Determine agent type by trying local first, then cloud.
|
id=str(uuid.uuid4()),
|
||||||
# Keep the full config object so we can pass it to the agent runner.
|
user_id=current_user.id,
|
||||||
local_config: LocalAgentConfig | None = None
|
device_id=body.device_id,
|
||||||
cloud_config: CloudAgentConfig | None = None
|
name="Local Directory Monitor",
|
||||||
|
directory_paths=[body.directory],
|
||||||
local_result = await db.execute(
|
data_types=_to_data_types(body.what_to_extract),
|
||||||
select(LocalAgentConfig).where(
|
prompt_template=body.custom_agent_prompt or "",
|
||||||
LocalAgentConfig.id == agent_id,
|
agent_config=body.agent_config,
|
||||||
LocalAgentConfig.user_id == current_user.id,
|
file_extensions=[],
|
||||||
)
|
schedule_cron=body.batch_interval,
|
||||||
|
enabled=True,
|
||||||
|
last_run_at=last_run_dt,
|
||||||
)
|
)
|
||||||
local_config = local_result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if local_config is not None:
|
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
||||||
agent_type = "local"
|
stable_agent_id = body.agent_id or config.id
|
||||||
else:
|
|
||||||
cloud_result = await db.execute(
|
if is_agent_running(stable_agent_id):
|
||||||
select(CloudAgentConfig).where(
|
raise HTTPException(
|
||||||
CloudAgentConfig.id == agent_id,
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
CloudAgentConfig.user_id == current_user.id,
|
detail="Agent is already running. Only one run per agent is allowed at a time.",
|
||||||
)
|
|
||||||
)
|
)
|
||||||
cloud_config = cloud_result.scalar_one_or_none()
|
|
||||||
if cloud_config is not None:
|
|
||||||
agent_type = "cloud"
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
|
|
||||||
run_log = AgentRunLog(
|
run_log = AgentRunLog(
|
||||||
agent_id=agent_id,
|
agent_id=stable_agent_id,
|
||||||
agent_type=agent_type,
|
agent_type="local",
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
status="running",
|
status="running",
|
||||||
)
|
)
|
||||||
@@ -439,14 +219,14 @@ async def trigger_agent_run(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(run_log)
|
await db.refresh(run_log)
|
||||||
|
|
||||||
# Dispatch the run as a background task — returns 202 immediately.
|
run_context = {
|
||||||
if agent_type == "local" and local_config is not None:
|
"type": "agent_batch",
|
||||||
asyncio.create_task(
|
"run_id": run_log.id,
|
||||||
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
"agent_id": stable_agent_id,
|
||||||
)
|
}
|
||||||
elif agent_type == "cloud" and cloud_config is not None:
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
||||||
)
|
)
|
||||||
|
|
||||||
return _to_run_log_response(run_log)
|
return _to_run_log_response(run_log)
|
||||||
|
|||||||
@@ -1,34 +1,68 @@
|
|||||||
"""Auth routes: register, login, refresh, me.
|
"""Auth routes: register, login, refresh, me, OAuth social login, onboarding.
|
||||||
|
|
||||||
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||||
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||||
SHA-256 hashes so plaintext never reaches the DB.
|
SHA-256 hashes so plaintext never reaches the DB.
|
||||||
|
|
||||||
|
OAuth (Google):
|
||||||
|
GET /auth/oauth/{provider}/authorize — returns consent-screen URL + state
|
||||||
|
POST /auth/oauth/{provider}/callback — exchanges code, issues JWT tokens
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
|
import urllib.parse
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import RefreshToken, User
|
from app.models import OAuthAccount, RefreshToken, User
|
||||||
from app.schemas import AuthTokens, UserProfile
|
from app.schemas import AuthTokens, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth provider registry ───────────────────────────────────────────
|
||||||
|
|
||||||
|
def _get_google_provider() -> GoogleOAuthProvider:
|
||||||
|
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
"Google login is not configured on this server",
|
||||||
|
)
|
||||||
|
return GoogleOAuthProvider(
|
||||||
|
client_id=settings.GOOGLE_AUTH_CLIENT_ID,
|
||||||
|
client_secret=settings.GOOGLE_AUTH_CLIENT_SECRET,
|
||||||
|
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_PROVIDERS = {"google": _get_google_provider}
|
||||||
|
|
||||||
|
# In-memory state store: state → (code_verifier, expires_at_epoch_s)
|
||||||
|
# Production note: replace with Redis for multi-process deployments.
|
||||||
|
_pending_states: dict[str, tuple[str, float]] = {}
|
||||||
|
_STATE_TTL_SECONDS = 600 # 10 minutes
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helpers ─────────────────────────────────────────────────
|
# ── Internal helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -231,5 +265,531 @@ async def update_profile(
|
|||||||
email=user.email,
|
email=user.email,
|
||||||
name=user.name,
|
name=user.name,
|
||||||
surname=user.surname,
|
surname=user.surname,
|
||||||
|
avatar_url=user.avatar_url,
|
||||||
tier=current_user.tier,
|
tier=current_user.tier,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth helpers ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]:
|
||||||
|
"""Create a refresh token row and return (plain_token, AuthTokens)."""
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return plain_token, AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth request/response schemas ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _OAuthAuthorizeResponse(BaseModel):
|
||||||
|
url: str
|
||||||
|
state: str
|
||||||
|
|
||||||
|
|
||||||
|
class _OAuthCallbackRequest(BaseModel):
|
||||||
|
code: str
|
||||||
|
state: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth routes ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/oauth/{provider}/web-callback",
|
||||||
|
summary="Web-facing OAuth redirect — bounces to the adiuvai:// deep link",
|
||||||
|
include_in_schema=False,
|
||||||
|
)
|
||||||
|
async def oauth_web_callback(
|
||||||
|
provider: Literal["google"],
|
||||||
|
code: str,
|
||||||
|
state: str,
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Google redirects here after user consent.
|
||||||
|
|
||||||
|
This endpoint immediately redirects to the Electron deep-link URI so the
|
||||||
|
desktop app receives the authorization code. It is intentionally simple —
|
||||||
|
no state validation here (the Electron app + backend callback do that).
|
||||||
|
|
||||||
|
Registered in Google Cloud Console as:
|
||||||
|
http://localhost:8000/api/v1/auth/oauth/google/web-callback (dev)
|
||||||
|
https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback (prod)
|
||||||
|
"""
|
||||||
|
params = urllib.parse.urlencode({"code": code, "state": state, "provider": provider})
|
||||||
|
deep_link = f"adiuvai://oauth/callback?{params}"
|
||||||
|
return RedirectResponse(url=deep_link, status_code=302)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/oauth/{provider}/authorize",
|
||||||
|
response_model=_OAuthAuthorizeResponse,
|
||||||
|
summary="Start OAuth flow — returns the provider consent-screen URL",
|
||||||
|
)
|
||||||
|
async def oauth_authorize(
|
||||||
|
provider: Literal["google"],
|
||||||
|
) -> _OAuthAuthorizeResponse:
|
||||||
|
"""Generate a PKCE state + code_challenge and return the authorization URL.
|
||||||
|
|
||||||
|
The client opens this URL in the system browser. After the user grants
|
||||||
|
consent, the provider redirects to the deep-link URI (adiuvai://oauth/callback)
|
||||||
|
with ``code`` and ``state`` query params. The client then calls
|
||||||
|
``POST /auth/oauth/{provider}/callback`` with those values.
|
||||||
|
"""
|
||||||
|
provider_factory = _PROVIDERS.get(provider)
|
||||||
|
if provider_factory is None:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
oauth_provider = provider_factory()
|
||||||
|
state = str(uuid.uuid4())
|
||||||
|
code_verifier, code_challenge = generate_pkce_pair()
|
||||||
|
|
||||||
|
# Purge expired states to prevent unbounded growth.
|
||||||
|
now = time.time()
|
||||||
|
expired = [s for s, (_, exp) in _pending_states.items() if exp < now]
|
||||||
|
for s in expired:
|
||||||
|
del _pending_states[s]
|
||||||
|
|
||||||
|
_pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS)
|
||||||
|
|
||||||
|
url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge)
|
||||||
|
return _OAuthAuthorizeResponse(url=url, state=state)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/oauth/{provider}/callback",
|
||||||
|
response_model=AuthTokens,
|
||||||
|
summary="Complete OAuth flow — exchange code and issue JWT tokens",
|
||||||
|
)
|
||||||
|
async def oauth_callback(
|
||||||
|
provider: Literal["google"],
|
||||||
|
body: _OAuthCallbackRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Validate state, exchange the authorization code, and sign in (or register) the user.
|
||||||
|
|
||||||
|
Resolution order:
|
||||||
|
1. ``oauth_accounts`` row match → existing user, log in.
|
||||||
|
2. Email match + ``email_verified=True`` → link OAuth account to existing user.
|
||||||
|
3. No match → create new user (password_hash=None, avatar from provider).
|
||||||
|
"""
|
||||||
|
provider_factory = _PROVIDERS.get(provider)
|
||||||
|
if provider_factory is None:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
# Validate state (CSRF protection).
|
||||||
|
now = time.time()
|
||||||
|
entry = _pending_states.pop(body.state, None)
|
||||||
|
if entry is None or entry[1] < now:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
|
||||||
|
|
||||||
|
code_verifier, _ = entry
|
||||||
|
|
||||||
|
oauth_provider = provider_factory()
|
||||||
|
|
||||||
|
# Exchange code for tokens.
|
||||||
|
try:
|
||||||
|
token_data = await oauth_provider.exchange_code(
|
||||||
|
code=body.code,
|
||||||
|
code_verifier=code_verifier,
|
||||||
|
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code"
|
||||||
|
)
|
||||||
|
|
||||||
|
access_token_google = token_data.get("access_token")
|
||||||
|
if not access_token_google:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response")
|
||||||
|
|
||||||
|
# Fetch user identity.
|
||||||
|
try:
|
||||||
|
userinfo = await oauth_provider.get_userinfo(access_token_google)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider")
|
||||||
|
|
||||||
|
# ── Resolution order ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
# 1. Existing OAuth link?
|
||||||
|
oauth_result = await db.execute(
|
||||||
|
select(OAuthAccount).where(
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
OAuthAccount.provider_user_id == userinfo.provider_user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
oauth_account = oauth_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if oauth_account is not None:
|
||||||
|
user_result = await db.execute(select(User).where(User.id == oauth_account.user_id))
|
||||||
|
user = user_result.scalar_one()
|
||||||
|
# Backfill avatar if the user doesn't have one yet.
|
||||||
|
if user.avatar_url is None and userinfo.avatar_url:
|
||||||
|
user.avatar_url = userinfo.avatar_url
|
||||||
|
await db.commit()
|
||||||
|
plain_token, tokens = await _issue_refresh_token(user, db)
|
||||||
|
await db.commit()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
# 2. Email match with a verified Google email → link accounts.
|
||||||
|
if userinfo.email_verified:
|
||||||
|
email_result = await db.execute(select(User).where(User.email == userinfo.email))
|
||||||
|
existing_user = email_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if existing_user is not None:
|
||||||
|
new_link = OAuthAccount(
|
||||||
|
user_id=existing_user.id,
|
||||||
|
provider=provider,
|
||||||
|
provider_user_id=userinfo.provider_user_id,
|
||||||
|
provider_email=userinfo.email,
|
||||||
|
)
|
||||||
|
db.add(new_link)
|
||||||
|
if existing_user.avatar_url is None and userinfo.avatar_url:
|
||||||
|
existing_user.avatar_url = userinfo.avatar_url
|
||||||
|
plain_token, tokens = await _issue_refresh_token(existing_user, db)
|
||||||
|
await db.commit()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
# Guard: if the email is already taken but we couldn't auto-link (e.g.
|
||||||
|
# email_verified=False), refuse with 409 instead of hitting a DB constraint.
|
||||||
|
if not userinfo.email_verified:
|
||||||
|
conflict = await db.execute(select(User).where(User.email == userinfo.email))
|
||||||
|
if conflict.scalar_one_or_none() is not None:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_409_CONFLICT,
|
||||||
|
"An account with this email already exists. "
|
||||||
|
"Please sign in with your password.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. New user — social-only account (no password).
|
||||||
|
new_user = User(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
email=userinfo.email,
|
||||||
|
name=userinfo.name,
|
||||||
|
password_hash=None,
|
||||||
|
avatar_url=userinfo.avatar_url,
|
||||||
|
tier="free",
|
||||||
|
encryption_key=Fernet.generate_key().decode(),
|
||||||
|
)
|
||||||
|
db.add(new_user)
|
||||||
|
await db.flush() # populate new_user.id
|
||||||
|
|
||||||
|
new_oauth = OAuthAccount(
|
||||||
|
user_id=new_user.id,
|
||||||
|
provider=provider,
|
||||||
|
provider_user_id=userinfo.provider_user_id,
|
||||||
|
provider_email=userinfo.email,
|
||||||
|
)
|
||||||
|
db.add(new_oauth)
|
||||||
|
|
||||||
|
plain_token, tokens = await _issue_refresh_token(new_user, db)
|
||||||
|
await db.commit()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
# ── Onboarding helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile:
|
||||||
|
"""Re-fetch and return a full UserProfile (reuses get_current_user logic)."""
|
||||||
|
|
||||||
|
# We can't call the FastAPI dependency directly, but we can replicate
|
||||||
|
# the core logic inline. Instead, we just re-query the same way.
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
|
user_result = await db.execute(
|
||||||
|
select(
|
||||||
|
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||||
|
User.password_hash,
|
||||||
|
).where(User.id == user_id)
|
||||||
|
)
|
||||||
|
user_row = user_result.one_or_none()
|
||||||
|
|
||||||
|
onboarding_ms: int | None = None
|
||||||
|
if user_row and user_row.onboarding_completed_at is not None:
|
||||||
|
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||||
|
|
||||||
|
memory_dict: dict[str, str] = {}
|
||||||
|
try:
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
blocks = await mw.list_core_blocks(user_id)
|
||||||
|
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user_id,
|
||||||
|
email=email,
|
||||||
|
name=user_row.name if user_row else None,
|
||||||
|
surname=user_row.surname if user_row else None,
|
||||||
|
avatar_url=user_row.avatar_url if user_row else None,
|
||||||
|
has_password=bool(user_row.password_hash) if user_row else False,
|
||||||
|
tier=tier,
|
||||||
|
onboarding_completed_at=onboarding_ms,
|
||||||
|
memory=memory_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Onboarding routes ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateMemoryRequest(BaseModel):
|
||||||
|
memory: dict[str, str] = Field(default_factory=dict)
|
||||||
|
mark_onboarded: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/memory", response_model=UserProfile)
|
||||||
|
async def update_memory(
|
||||||
|
body: _UpdateMemoryRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Update core memory key/value pairs and optionally mark onboarding complete."""
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
for key, value in body.memory.items():
|
||||||
|
await mw.update_core(current_user.id, key, value)
|
||||||
|
if body.mark_onboarded:
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.onboarding_completed_at = datetime.now(timezone.utc)
|
||||||
|
await db.commit()
|
||||||
|
return await _build_profile(current_user.id, current_user.email, db)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/me/onboarding/reset")
|
||||||
|
async def reset_onboarding(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
):
|
||||||
|
"""Reset onboarding so the wizard runs again on next login."""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.onboarding_completed_at = None
|
||||||
|
await db.commit()
|
||||||
|
return {"status": "reset"}
|
||||||
|
|
||||||
|
|
||||||
|
class _NormalizeRequest(BaseModel):
|
||||||
|
inputs: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class _NormalizeResponse(BaseModel):
|
||||||
|
normalized: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/onboarding/normalize", response_model=_NormalizeResponse)
|
||||||
|
async def normalize_onboarding(
|
||||||
|
body: _NormalizeRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> _NormalizeResponse:
|
||||||
|
"""One-shot LLM normalization for free-text onboarding answers."""
|
||||||
|
if not body.inputs:
|
||||||
|
return _NormalizeResponse(normalized={})
|
||||||
|
try:
|
||||||
|
llm = get_llm(model="gpt-4o-mini", temperature=0)
|
||||||
|
prompt = (
|
||||||
|
"You normalize user onboarding answers into clean, ≤3-word canonical labels.\n"
|
||||||
|
"Return a JSON object with the same keys and normalized values.\n"
|
||||||
|
"Examples: 'i build websites' → 'Web Developer', 'tech-ish stuff' → 'Technology'\n"
|
||||||
|
f"Input: {json.dumps(body.inputs)}"
|
||||||
|
)
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[
|
||||||
|
{"role": "system", "content": "You normalize user inputs. Return JSON only."},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
normalized = json.loads(response.content)
|
||||||
|
return _NormalizeResponse(normalized=normalized)
|
||||||
|
except Exception:
|
||||||
|
# LLM failure must never block onboarding — return inputs unchanged
|
||||||
|
return _NormalizeResponse(normalized=body.inputs)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Password management ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _ChangePasswordRequest(BaseModel):
|
||||||
|
current_password: str = Field(min_length=1)
|
||||||
|
new_password: str = Field(min_length=8)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/password", status_code=status.HTTP_200_OK)
|
||||||
|
async def change_password(
|
||||||
|
body: _ChangePasswordRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Change the authenticated user's password.
|
||||||
|
|
||||||
|
Requires the current password for verification.
|
||||||
|
Returns 400 for social-only users (no password set).
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
if user.password_hash is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST,
|
||||||
|
"This account uses social login and has no password to change",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _verify_password(body.current_password, user.password_hash):
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")
|
||||||
|
|
||||||
|
user.password_hash = _hash_password(body.new_password)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth account management ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me/oauth-accounts", response_model=list[dict])
|
||||||
|
async def list_oauth_accounts(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[dict]:
|
||||||
|
"""List all OAuth providers linked to the authenticated user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
accounts = result.scalars().all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"provider": a.provider,
|
||||||
|
"provider_email": a.provider_email,
|
||||||
|
"created_at": int(a.created_at.timestamp() * 1000),
|
||||||
|
}
|
||||||
|
for a in accounts
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
|
||||||
|
async def unlink_oauth_account(
|
||||||
|
provider: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Unlink an OAuth provider from the authenticated user.
|
||||||
|
|
||||||
|
Refuses if the user has no password and this is their only login method.
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
oauth_result = await db.execute(
|
||||||
|
select(OAuthAccount).where(
|
||||||
|
OAuthAccount.user_id == current_user.id,
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
account = oauth_result.scalar_one_or_none()
|
||||||
|
if account is None:
|
||||||
|
raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")
|
||||||
|
|
||||||
|
# Safety: don't let users lock themselves out.
|
||||||
|
all_oauth = await db.execute(
|
||||||
|
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
oauth_count = len(all_oauth.scalars().all())
|
||||||
|
|
||||||
|
if user.password_hash is None and oauth_count <= 1:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST,
|
||||||
|
"Cannot unlink the only login method. Set a password first.",
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.delete(account)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Avatar update ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateAvatarRequest(BaseModel):
|
||||||
|
avatar_url: str = Field(min_length=1)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/avatar", response_model=UserProfile)
|
||||||
|
async def update_avatar(
|
||||||
|
body: _UpdateAvatarRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Update the authenticated user's avatar URL.
|
||||||
|
|
||||||
|
Accepts {"avatar_url": "https://..."} — the client uploads the image
|
||||||
|
to its own storage and passes the resulting URL here.
|
||||||
|
"""
|
||||||
|
if not body.avatar_url.startswith(("https://", "http://", "data:image/")):
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")
|
||||||
|
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.avatar_url = body.avatar_url
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return await _build_profile(current_user.id, current_user.email, db)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Account deletion ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/me", status_code=status.HTTP_200_OK)
|
||||||
|
async def delete_account(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Permanently delete the authenticated user's account.
|
||||||
|
|
||||||
|
Cascades: refresh tokens, OAuth accounts, subscription, and all memory
|
||||||
|
rows are deleted via SQLAlchemy relationship cascades. Stripe subscription
|
||||||
|
is cancelled if active.
|
||||||
|
"""
|
||||||
|
# Cancel Stripe subscription if present.
|
||||||
|
try:
|
||||||
|
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
||||||
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
|
except HTTPException:
|
||||||
|
pass # No subscription — that's fine
|
||||||
|
|
||||||
|
# Delete all memory rows (core, associative, episodic, proactive).
|
||||||
|
try:
|
||||||
|
from app.models import ( # noqa: PLC0415
|
||||||
|
MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
|
||||||
|
)
|
||||||
|
for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
|
||||||
|
await db.execute(
|
||||||
|
model.__table__.delete().where(model.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Non-critical — cascade on User will handle most
|
||||||
|
|
||||||
|
# Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
await db.delete(user)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|||||||
@@ -1,171 +0,0 @@
|
|||||||
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
|
||||||
|
|
||||||
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
|
|
||||||
PostgreSQL ``backup_metadata`` table.
|
|
||||||
|
|
||||||
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
|
||||||
treating "history" as a ``{backup_id}`` path parameter.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from email.utils import parsedate_to_datetime
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.billing.tier_manager import tier_manager
|
|
||||||
from app.db import get_session
|
|
||||||
from app.models import BackupMetadata as BackupMetadataModel
|
|
||||||
from app.schemas import BackupMetadata, UserProfile
|
|
||||||
from app.storage.blob_store import BlobStore
|
|
||||||
from app.storage.encryption import reject_if_tampered
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/backup", tags=["backup"])
|
|
||||||
|
|
||||||
_blob_store = BlobStore()
|
|
||||||
|
|
||||||
|
|
||||||
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
|
|
||||||
"""Return total backup bytes stored by *user_id*."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
|
|
||||||
BackupMetadataModel.user_id == user_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return int(result.scalar_one())
|
|
||||||
|
|
||||||
|
|
||||||
async def _check_backup_quota(
|
|
||||||
user: UserProfile, size_bytes: int, db: AsyncSession
|
|
||||||
) -> None:
|
|
||||||
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
|
||||||
current = await _current_backup_bytes(user.id, db)
|
|
||||||
tier_manager.enforce_backup_quota(
|
|
||||||
user.tier, current_bytes=current, additional_bytes=size_bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("")
|
|
||||||
async def upload_backup(
|
|
||||||
request: Request,
|
|
||||||
x_backup_version: int = Header(..., alias="X-Backup-Version"),
|
|
||||||
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
|
|
||||||
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Upload an E2E-encrypted backup blob.
|
|
||||||
|
|
||||||
Metadata is passed via custom headers; the raw body is the encrypted blob.
|
|
||||||
"""
|
|
||||||
blob = await request.body()
|
|
||||||
reject_if_tampered(blob, x_backup_checksum)
|
|
||||||
await _check_backup_quota(current_user, len(blob), db)
|
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
|
||||||
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
|
||||||
)
|
|
||||||
|
|
||||||
row = BackupMetadataModel(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=current_user.id,
|
|
||||||
s3_key=s3_key,
|
|
||||||
version=x_backup_version,
|
|
||||||
timestamp=x_backup_timestamp,
|
|
||||||
checksum=x_backup_checksum,
|
|
||||||
size_bytes=len(blob),
|
|
||||||
)
|
|
||||||
db.add(row)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/history", response_model=list[BackupMetadata])
|
|
||||||
async def backup_history(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[BackupMetadata]:
|
|
||||||
"""Return backup metadata records for the authenticated user (no blob bytes)."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(BackupMetadataModel)
|
|
||||||
.where(BackupMetadataModel.user_id == current_user.id)
|
|
||||||
.order_by(BackupMetadataModel.timestamp.desc())
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
return [
|
|
||||||
BackupMetadata(
|
|
||||||
version=r.version,
|
|
||||||
timestamp=r.timestamp,
|
|
||||||
checksum=r.checksum,
|
|
||||||
chunk_count=1,
|
|
||||||
)
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("")
|
|
||||||
async def download_backup(
|
|
||||||
request: Request,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> Response:
|
|
||||||
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(BackupMetadataModel)
|
|
||||||
.where(BackupMetadataModel.user_id == current_user.id)
|
|
||||||
.order_by(BackupMetadataModel.timestamp.desc())
|
|
||||||
.limit(1)
|
|
||||||
)
|
|
||||||
latest = result.scalar_one_or_none()
|
|
||||||
if latest is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
|
|
||||||
|
|
||||||
ims_header = request.headers.get("If-Modified-Since")
|
|
||||||
if ims_header:
|
|
||||||
try:
|
|
||||||
ims_dt = parsedate_to_datetime(ims_header)
|
|
||||||
ims_ms = int(ims_dt.timestamp() * 1000)
|
|
||||||
if latest.timestamp <= ims_ms:
|
|
||||||
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
|
|
||||||
except Exception:
|
|
||||||
pass # malformed header — ignore and serve the blob
|
|
||||||
|
|
||||||
blob = await _blob_store.download(current_user.id, latest.s3_key)
|
|
||||||
return Response(
|
|
||||||
content=blob,
|
|
||||||
media_type="application/octet-stream",
|
|
||||||
headers={
|
|
||||||
"X-Backup-Version": str(latest.version),
|
|
||||||
"X-Backup-Timestamp": str(latest.timestamp),
|
|
||||||
"X-Checksum": latest.checksum,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{backup_id}", response_model=dict)
|
|
||||||
async def delete_backup(
|
|
||||||
backup_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a specific backup by ID."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(BackupMetadataModel).where(
|
|
||||||
BackupMetadataModel.id == backup_id,
|
|
||||||
BackupMetadataModel.user_id == current_user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
target = result.scalar_one_or_none()
|
|
||||||
if target is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
|
|
||||||
|
|
||||||
await _blob_store.delete(current_user.id, target.s3_key)
|
|
||||||
await db.delete(target)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
@@ -83,3 +83,16 @@ async def cancel_subscription(
|
|||||||
"""Cancel the active subscription."""
|
"""Cancel the active subscription."""
|
||||||
await stripe_service.cancel_subscription(current_user.id, db)
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/invoices", response_model=list[dict])
|
||||||
|
async def list_invoices(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return billing history (invoices) from Stripe.
|
||||||
|
|
||||||
|
Returns an empty list when Stripe is not configured.
|
||||||
|
"""
|
||||||
|
invoices = await stripe_service.list_invoices(current_user.id, db)
|
||||||
|
return invoices
|
||||||
|
|||||||
@@ -1,29 +1,116 @@
|
|||||||
"""Chat routes: POST /chat (REST fallback).
|
"""Chat routes: POST /chat (REST fallback) and POST /chat/embed (text → vector).
|
||||||
|
|
||||||
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
import uuid
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.core.orchestrator import orchestrate
|
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||||
|
from app.core.deep_agent import run_home
|
||||||
|
from app.core.llm import embed
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import async_session
|
||||||
from app.schemas import ChatRequest, UserProfile
|
from app.schemas import ChatRequest, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Embed helpers ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _EmbedRequest(BaseModel):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class _EmbedResponse(BaseModel):
|
||||||
|
vector: list[float]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Endpoints ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def chat(
|
async def chat(
|
||||||
body: ChatRequest,
|
body: ChatRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""Route a chat message through the orchestrator.
|
"""REST fallback for home chat when websocket streaming is unavailable."""
|
||||||
|
response = await run_home(
|
||||||
|
user_id=current_user.id,
|
||||||
|
message=body.message,
|
||||||
|
context=body.context.model_dump(),
|
||||||
|
)
|
||||||
|
return JSONResponse(content={"response": response})
|
||||||
|
|
||||||
Returns ``ChatResponse`` for ``execution_mode='direct'``,
|
|
||||||
or ``ExecutionPlan`` for ``execution_mode='plan'``.
|
class _BriefRequest(BaseModel):
|
||||||
|
mode: Literal["home", "project"]
|
||||||
|
project_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class _BriefResponse(BaseModel):
|
||||||
|
response: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/brief", response_model=_BriefResponse)
|
||||||
|
async def brief(
|
||||||
|
body: _BriefRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> _BriefResponse:
|
||||||
|
"""REST fallback for brief when the device WebSocket is not ready."""
|
||||||
|
if body.mode == "project":
|
||||||
|
if not body.project_id:
|
||||||
|
raise HTTPException(status_code=422, detail="project_id required for project mode")
|
||||||
|
try:
|
||||||
|
uuid.UUID(body.project_id)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(status_code=422, detail="project_id must be a valid UUID")
|
||||||
|
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(
|
||||||
|
current_user.id,
|
||||||
|
"",
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"_debug": {"request_id": request_id, "user_id": current_user.id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks: list[str] = []
|
||||||
|
if body.mode == "project":
|
||||||
|
stream = run_project_brief(current_user.id, body.project_id, context) # type: ignore[arg-type]
|
||||||
|
else:
|
||||||
|
stream = run_home_brief(current_user.id, context)
|
||||||
|
|
||||||
|
async for event_type, data in stream:
|
||||||
|
if event_type == "token" and data:
|
||||||
|
chunks.append(str(data))
|
||||||
|
|
||||||
|
return _BriefResponse(response="".join(chunks))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/embed", response_model=_EmbedResponse)
|
||||||
|
async def embed_text(
|
||||||
|
body: _EmbedRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> _EmbedResponse:
|
||||||
|
"""Generate a 1536-dim embedding vector for the given text.
|
||||||
|
|
||||||
|
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
||||||
|
Used by Electron (vectordb.ts) for local note search.
|
||||||
"""
|
"""
|
||||||
result = await orchestrate(body)
|
vector = await embed(body.text)
|
||||||
return JSONResponse(content=result.model_dump())
|
return _EmbedResponse(vector=vector)
|
||||||
|
|||||||
@@ -14,11 +14,11 @@ Protocol:
|
|||||||
4. Session enters message dispatch loop + heartbeat.
|
4. Session enters message dispatch loop + heartbeat.
|
||||||
|
|
||||||
Incoming frame dispatch:
|
Incoming frame dispatch:
|
||||||
- ``tool_result`` → resolves a pending tool-call Future.
|
- ``tool_result`` → resolves a pending tool-call Future.
|
||||||
- ``agent_data`` → enqueued in the per-run agent data queue.
|
- ``journey_start`` → starts a guided setup journey session.
|
||||||
- ``agent_complete`` → sends None sentinel to close the queue stream.
|
- ``journey_message`` → continues a journey conversation.
|
||||||
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||||
- unknown types → logged, ignored.
|
- unknown types → logged, ignored.
|
||||||
|
|
||||||
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
|
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
|
||||||
|
|
||||||
@@ -39,16 +39,18 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
|
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||||
|
from app.core.deep_agent import run_floating_stream, run_home_stream
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.orchestrator import orchestrate_v3_stream
|
from app.core.output_formatter import StreamFormatter
|
||||||
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
from app.schemas import WsFrameType
|
from app.schemas import WsFrameType, WsStreamEnd
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -147,37 +149,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
"device_ws: tool_result missing id from user=%s", user_id
|
"device_ws: tool_result missing id from user=%s", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.agent_data:
|
|
||||||
run_id = frame.get("run_id")
|
|
||||||
if run_id:
|
|
||||||
try:
|
|
||||||
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
|
||||||
await queue.put(frame)
|
|
||||||
except RuntimeError:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_data for unknown run user=%s run=%s",
|
|
||||||
user_id,
|
|
||||||
run_id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_data missing run_id from user=%s", user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.agent_complete:
|
|
||||||
run_id = frame.get("run_id")
|
|
||||||
if run_id:
|
|
||||||
try:
|
|
||||||
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
|
||||||
# Sentinel: signals the agent data stream is finished.
|
|
||||||
await queue.put(None)
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_complete missing run_id from user=%s", user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.home_request:
|
elif frame_type == WsFrameType.home_request:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_handle_home_request(websocket, user_id, frame)
|
_handle_home_request(websocket, user_id, frame)
|
||||||
@@ -188,6 +159,21 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_floating_request(websocket, user_id, frame)
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.brief_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_brief_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.journey_start:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_journey_start(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.journey_message:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_journey_message(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -219,33 +205,37 @@ async def _handle_home_request(
|
|||||||
request_id = frame.get("request_id") or str(uuid4())
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
logger.info(
|
||||||
|
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message)
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
message,
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
agent_holder: list = []
|
|
||||||
try:
|
try:
|
||||||
token_stream = orchestrate_v3_stream(
|
event_stream = run_home_stream(user_id, message, context)
|
||||||
user_id, message, context, agent_holder=agent_holder
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
)
|
async for ws_frame in formatter.format(event_stream):
|
||||||
formatter = HomeFormatter(request_id=request_id, tool_results=[])
|
|
||||||
async for ws_frame in formatter.format(token_stream):
|
|
||||||
# Inject mutations from agent tool_results into stream_end
|
|
||||||
if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
|
|
||||||
ws_frame.mutations = [ # type: ignore[union-attr]
|
|
||||||
{"action": r["action"], "table": r["table"], "data": r["data"]}
|
|
||||||
for r in getattr(agent_holder[0], "tool_results", [])
|
|
||||||
]
|
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
# Collect text chunks to build the full response for episode storage
|
# Collect text chunks to build the full response for episode storage
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
@@ -262,8 +252,15 @@ async def _handle_home_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks)
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
len("".join(response_chunks)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_floating_request(
|
async def _handle_floating_request(
|
||||||
@@ -276,29 +273,38 @@ async def _handle_floating_request(
|
|||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
scope: dict = frame.get("scope", {})
|
scope: dict = frame.get("scope", {})
|
||||||
|
logger.info(
|
||||||
|
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
json.dumps(scope, ensure_ascii=True)[:200],
|
||||||
|
message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message)
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
message,
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
context: dict = {"scope": scope, **memory_context}
|
context: dict = {
|
||||||
|
"scope": scope,
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
agent_holder: list = []
|
|
||||||
try:
|
try:
|
||||||
token_stream = orchestrate_v3_stream(
|
event_stream = run_floating_stream(user_id, message, context)
|
||||||
user_id, message, context, agent_holder=agent_holder
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
)
|
async for ws_frame in formatter.format(event_stream):
|
||||||
formatter = FloatingFormatter(request_id=request_id)
|
|
||||||
async for ws_frame in formatter.format(token_stream):
|
|
||||||
if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
|
|
||||||
ws_frame.mutations = [ # type: ignore[union-attr]
|
|
||||||
{"action": r["action"], "table": r["table"], "data": r["data"]}
|
|
||||||
for r in getattr(agent_holder[0], "tool_results", [])
|
|
||||||
]
|
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
@@ -314,8 +320,152 @@ async def _handle_floating_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks)
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
len("".join(response_chunks)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_brief_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a brief_request frame — streams plain-text brief back on the socket.
|
||||||
|
|
||||||
|
No episode storage — briefs are not conversations.
|
||||||
|
"""
|
||||||
|
import uuid as _uuid
|
||||||
|
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
session_id = frame.get("session_id") or str(uuid4())
|
||||||
|
mode: str = frame.get("mode", "home")
|
||||||
|
project_id: str | None = frame.get("project_id")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"device_ws: brief_request_start user=%s req=%s mode=%s project_id=%s",
|
||||||
|
user_id, request_id, mode, project_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate project_id for project mode before touching LLM.
|
||||||
|
if mode == "project":
|
||||||
|
try:
|
||||||
|
if not project_id:
|
||||||
|
raise ValueError("project_id required for project mode")
|
||||||
|
_uuid.UUID(project_id)
|
||||||
|
except (ValueError, AttributeError) as exc:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: brief_request invalid project_id user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(
|
||||||
|
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Enrich context with memory (no user message — use empty string as probe).
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
"",
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
if mode == "project":
|
||||||
|
event_stream = run_project_brief(user_id, project_id, context) # type: ignore[arg-type]
|
||||||
|
else:
|
||||||
|
event_stream = run_home_brief(user_id, context)
|
||||||
|
|
||||||
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: brief_request failed user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(
|
||||||
|
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"device_ws: brief_request_end user=%s req=%s mode=%s",
|
||||||
|
user_id, request_id, mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_start(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a journey_start frame — explores directory and sends first question."""
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_start(user_id, frame)
|
||||||
|
await websocket.send_text(json.dumps(reply))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: journey_start failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": frame.get("session_id", ""),
|
||||||
|
"message": f"Failed to start journey: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}))
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_message(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a journey_message frame — continues the journey conversation."""
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_message(user_id, frame)
|
||||||
|
await websocket.send_text(json.dumps(reply))
|
||||||
|
except Exception as exc:
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
logger.error(
|
||||||
|
"device_ws: journey_message failed user=%s session=%s: %s",
|
||||||
|
user_id, session_id, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": f"Journey error: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}))
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
@@ -351,6 +501,3 @@ async def _mark_runs_disconnected(user_id: str) -> None:
|
|||||||
user_id,
|
user_id,
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
225
app/api/routes/memory.py
Normal file
225
app/api/routes/memory.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
"""Memory management routes — view/edit/delete user memory tiers.
|
||||||
|
|
||||||
|
All routes require authentication. Data is always user-scoped.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import (
|
||||||
|
ExtractionQueue,
|
||||||
|
MemoryAssociative,
|
||||||
|
MemoryCore,
|
||||||
|
MemoryEpisodic,
|
||||||
|
MemoryProactive,
|
||||||
|
MemoryRelation,
|
||||||
|
)
|
||||||
|
from app.schemas import UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/memory", tags=["memory"])
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ALLOWED_PREDICATES = {
|
||||||
|
"works_at",
|
||||||
|
"reports_to",
|
||||||
|
"stakeholder_of",
|
||||||
|
"last_contacted_on",
|
||||||
|
"owes_followup",
|
||||||
|
"manages",
|
||||||
|
"collaborates_with",
|
||||||
|
"owns",
|
||||||
|
"member_of",
|
||||||
|
"custom",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Response schemas ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class RelationOut(BaseModel):
|
||||||
|
id: str
|
||||||
|
subject_label: str
|
||||||
|
subject_type: str
|
||||||
|
predicate: str
|
||||||
|
object_label: str
|
||||||
|
object_type: str
|
||||||
|
confidence: float
|
||||||
|
last_confirmed_at: int | None = None # epoch ms
|
||||||
|
|
||||||
|
|
||||||
|
class RelationPatch(BaseModel):
|
||||||
|
subject_label: str | None = None
|
||||||
|
object_label: str | None = None
|
||||||
|
predicate: str | None = None
|
||||||
|
confidence: float | None = Field(None, ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class CoreAddBody(BaseModel):
|
||||||
|
key: str = Field(..., min_length=1, max_length=255)
|
||||||
|
value: str = Field(..., min_length=1)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _relation_to_out(row: MemoryRelation) -> RelationOut:
|
||||||
|
last_ms: int | None = None
|
||||||
|
if row.last_confirmed_at is not None:
|
||||||
|
last_ms = int(row.last_confirmed_at.timestamp() * 1000)
|
||||||
|
return RelationOut(
|
||||||
|
id=row.id,
|
||||||
|
subject_label=row.subject_label,
|
||||||
|
subject_type=row.subject_type,
|
||||||
|
predicate=row.predicate,
|
||||||
|
object_label=row.object_label,
|
||||||
|
object_type=row.object_type,
|
||||||
|
confidence=row.confidence,
|
||||||
|
last_confirmed_at=last_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/core", response_model=dict[str, str])
|
||||||
|
async def get_core_memory(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Return all core memory k/v pairs (plaintext) for the current user."""
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
blocks = await mw.list_core_blocks(current_user.id)
|
||||||
|
return {b["label"]: b["value"] for b in blocks}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_core_key(
|
||||||
|
key: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> None:
|
||||||
|
"""Delete a single core memory key (GDPR Art. 17)."""
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
deleted = await mw.delete_core(current_user.id, key)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str])
|
||||||
|
async def add_core_key(
|
||||||
|
body: CoreAddBody,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Add or overwrite a core memory key/value pair."""
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
await mw.update_core(current_user.id, body.key, body.value)
|
||||||
|
return {body.key: body.value}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/relational", response_model=list[RelationOut])
|
||||||
|
async def get_relational_memory(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[RelationOut]:
|
||||||
|
"""Return all relational memory rows for the current user."""
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
rows = await mw.query_relations(current_user.id, limit=200)
|
||||||
|
return [_relation_to_out(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/relational/{relation_id}", response_model=RelationOut)
|
||||||
|
async def patch_relation(
|
||||||
|
relation_id: str,
|
||||||
|
body: RelationPatch,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> RelationOut:
|
||||||
|
"""Edit a relation row's labels, predicate, or confidence."""
|
||||||
|
if body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(MemoryRelation).where(
|
||||||
|
MemoryRelation.id == relation_id,
|
||||||
|
MemoryRelation.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
|
||||||
|
|
||||||
|
if body.subject_label is not None:
|
||||||
|
row.subject_label = body.subject_label
|
||||||
|
if body.object_label is not None:
|
||||||
|
row.object_label = body.object_label
|
||||||
|
if body.predicate is not None:
|
||||||
|
row.predicate = body.predicate
|
||||||
|
if body.confidence is not None:
|
||||||
|
row.confidence = body.confidence
|
||||||
|
row.last_confirmed_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(row)
|
||||||
|
logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id)
|
||||||
|
return _relation_to_out(row)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_relation(
|
||||||
|
relation_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> None:
|
||||||
|
"""Hard-delete a relation row (GDPR Art. 17)."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(MemoryRelation).where(
|
||||||
|
MemoryRelation.id == relation_id,
|
||||||
|
MemoryRelation.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
|
||||||
|
await db.delete(row)
|
||||||
|
await db.commit()
|
||||||
|
logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def forget_all(
|
||||||
|
x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> None:
|
||||||
|
"""Wipe all memory tiers for the current user (GDPR Art. 17).
|
||||||
|
|
||||||
|
Requires ``X-Confirm: true`` header. Does NOT delete the user account.
|
||||||
|
"""
|
||||||
|
if x_confirm != "true":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.",
|
||||||
|
)
|
||||||
|
|
||||||
|
uid = current_user.id
|
||||||
|
await db.execute(delete(MemoryCore).where(MemoryCore.user_id == uid))
|
||||||
|
await db.execute(delete(MemoryAssociative).where(MemoryAssociative.user_id == uid))
|
||||||
|
await db.execute(delete(MemoryEpisodic).where(MemoryEpisodic.user_id == uid))
|
||||||
|
await db.execute(delete(MemoryProactive).where(MemoryProactive.user_id == uid))
|
||||||
|
await db.execute(delete(MemoryRelation).where(MemoryRelation.user_id == uid))
|
||||||
|
await db.execute(delete(ExtractionQueue).where(ExtractionQueue.user_id == uid))
|
||||||
|
await db.commit()
|
||||||
|
logger.warning("memory: forget_all GDPR wipe user=%s", uid)
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.core.execution_plan import plan_cache
|
|
||||||
from app.schemas import ExecutionPlan, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/plans", tags=["plans"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/playbook", response_model=list[ExecutionPlan])
|
|
||||||
async def list_playbooks(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> list[ExecutionPlan]:
|
|
||||||
"""Return all cached execution plan playbooks for the authenticated user.
|
|
||||||
|
|
||||||
TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
|
|
||||||
"""
|
|
||||||
return plan_cache.get_all_playbooks()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
|
|
||||||
async def get_playbook(
|
|
||||||
plan_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> ExecutionPlan:
|
|
||||||
"""Return a specific execution plan playbook by ID."""
|
|
||||||
plan = plan_cache.get_plan(plan_id)
|
|
||||||
if plan is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Plan not found: {plan_id}",
|
|
||||||
)
|
|
||||||
return plan
|
|
||||||
@@ -1,148 +0,0 @@
|
|||||||
"""Plugins routes: browse and install plugins from the marketplace.
|
|
||||||
|
|
||||||
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
|
|
||||||
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.db import get_session
|
|
||||||
from app.marketplace.plugin_registry import registry
|
|
||||||
from app.marketplace.revenue_share import revenue_share
|
|
||||||
from app.models import PluginInstallation, PluginReview as PluginReviewModel
|
|
||||||
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tier gate ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _require_plugin_tier(user: UserProfile) -> None:
|
|
||||||
"""Raise HTTP 403 for users below Power tier."""
|
|
||||||
if user.tier not in ("power", "team"):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Plugin marketplace requires Power tier or above",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local detail schema ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _PluginDetail(BaseModel):
|
|
||||||
plugin: PluginManifest
|
|
||||||
install_count: int
|
|
||||||
ratings: list[Any]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("", response_model=PluginListResponse)
|
|
||||||
async def list_plugins(
|
|
||||||
category: str | None = Query(default=None),
|
|
||||||
q: str | None = Query(default=None),
|
|
||||||
page: int = Query(default=1, ge=1),
|
|
||||||
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> PluginListResponse:
|
|
||||||
"""Browse the plugin marketplace. Requires Power tier or above."""
|
|
||||||
_require_plugin_tier(current_user)
|
|
||||||
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
|
||||||
async def get_plugin(
|
|
||||||
plugin_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> _PluginDetail:
|
|
||||||
"""Get full plugin details including install count. Requires Power tier or above."""
|
|
||||||
_require_plugin_tier(current_user)
|
|
||||||
entry = await registry.get_plugin(db, plugin_id)
|
|
||||||
if entry is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
|
||||||
|
|
||||||
# Fetch review ratings for this plugin
|
|
||||||
review_result = await db.execute(
|
|
||||||
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
|
|
||||||
)
|
|
||||||
reviews = review_result.scalars().all()
|
|
||||||
ratings = [
|
|
||||||
{
|
|
||||||
"reviewer_id": r.reviewer_id,
|
|
||||||
"decision": r.decision,
|
|
||||||
"notes": r.notes,
|
|
||||||
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
|
|
||||||
}
|
|
||||||
for r in reviews
|
|
||||||
]
|
|
||||||
|
|
||||||
return _PluginDetail(
|
|
||||||
plugin=entry["manifest"],
|
|
||||||
install_count=entry["install_count"],
|
|
||||||
ratings=ratings,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{plugin_id}/install", response_model=dict)
|
|
||||||
async def install_plugin(
|
|
||||||
plugin_id: str,
|
|
||||||
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
|
||||||
|
|
||||||
Requires Power tier or above.
|
|
||||||
"""
|
|
||||||
_require_plugin_tier(current_user)
|
|
||||||
entry = await registry.get_plugin(db, plugin_id)
|
|
||||||
if entry is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
|
||||||
|
|
||||||
# Record the installation in plugin_installations
|
|
||||||
installation = PluginInstallation(
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
)
|
|
||||||
db.add(installation)
|
|
||||||
await db.flush()
|
|
||||||
|
|
||||||
await revenue_share.record_install(
|
|
||||||
db,
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
amount_cents=entry["manifest"].price_cents,
|
|
||||||
)
|
|
||||||
|
|
||||||
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
|
|
||||||
return {"ok": True, "download_url": download_url}
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{plugin_id}/install", response_model=dict)
|
|
||||||
async def uninstall_plugin(
|
|
||||||
plugin_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Unregister a plugin installation."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(PluginInstallation).where(
|
|
||||||
PluginInstallation.plugin_id == plugin_id,
|
|
||||||
PluginInstallation.user_id == current_user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
installation = result.scalar_one_or_none()
|
|
||||||
if installation is not None:
|
|
||||||
await db.delete(installation)
|
|
||||||
await db.commit()
|
|
||||||
await registry.record_uninstall(db, plugin_id)
|
|
||||||
return {"ok": True}
|
|
||||||
@@ -1,195 +0,0 @@
|
|||||||
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
|
||||||
|
|
||||||
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
|
|
||||||
PostgreSQL ``storage_records`` table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.billing.tier_manager import tier_manager
|
|
||||||
from app.db import get_session
|
|
||||||
from app.models import StorageRecord
|
|
||||||
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
|
||||||
from app.storage.blob_store import BlobStore
|
|
||||||
from app.storage.encryption import reject_if_tampered
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/storage", tags=["storage"])
|
|
||||||
|
|
||||||
_blob_store = BlobStore()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local response schemas ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _CreateResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
created_at: int
|
|
||||||
|
|
||||||
|
|
||||||
class _RecordMeta(BaseModel):
|
|
||||||
id: str
|
|
||||||
table: str
|
|
||||||
checksum: str
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
|
|
||||||
"""Return total bytes stored by *user_id*."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
|
|
||||||
StorageRecord.user_id == user_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return int(result.scalar_one())
|
|
||||||
|
|
||||||
|
|
||||||
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
|
|
||||||
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
|
|
||||||
current = await _current_usage_bytes(user.id, db)
|
|
||||||
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_record_for_user(
|
|
||||||
record_id: str, user_id: str, db: AsyncSession
|
|
||||||
) -> StorageRecord:
|
|
||||||
"""Look up a record and verify ownership. Returns 404 on mismatch
|
|
||||||
to prevent user enumeration attacks."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(StorageRecord).where(
|
|
||||||
StorageRecord.id == record_id, StorageRecord.user_id == user_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def create_record(
|
|
||||||
body: StorageRecordCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> _CreateResponse:
|
|
||||||
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
|
||||||
reject_if_tampered(body.blob, body.checksum)
|
|
||||||
await _check_quota(current_user, len(body.blob), db)
|
|
||||||
|
|
||||||
record_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
|
||||||
current_user.id, body.table, record_id, body.blob, body.checksum
|
|
||||||
)
|
|
||||||
|
|
||||||
record = StorageRecord(
|
|
||||||
id=record_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
table_name=body.table,
|
|
||||||
s3_key=s3_key,
|
|
||||||
checksum=body.checksum,
|
|
||||||
size_bytes=len(body.blob),
|
|
||||||
)
|
|
||||||
db.add(record)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(record)
|
|
||||||
|
|
||||||
created_at_ms = int(record.created_at.timestamp() * 1000)
|
|
||||||
return _CreateResponse(id=record_id, created_at=created_at_ms)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/records", response_model=list[_RecordMeta])
|
|
||||||
async def list_records(
|
|
||||||
table: str | None = Query(default=None),
|
|
||||||
page: int = Query(default=1, ge=1),
|
|
||||||
limit: int = Query(default=50, ge=1, le=200),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[_RecordMeta]:
|
|
||||||
"""List record metadata for the authenticated user. Blob bytes are never returned."""
|
|
||||||
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
|
|
||||||
if table is not None:
|
|
||||||
query = query.where(StorageRecord.table_name == table)
|
|
||||||
query = query.offset((page - 1) * limit).limit(limit)
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
|
|
||||||
return [
|
|
||||||
_RecordMeta(
|
|
||||||
id=r.id,
|
|
||||||
table=r.table_name,
|
|
||||||
checksum=r.checksum,
|
|
||||||
created_at=int(r.created_at.timestamp() * 1000),
|
|
||||||
updated_at=int(r.updated_at.timestamp() * 1000),
|
|
||||||
)
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/records/{record_id}")
|
|
||||||
async def download_record(
|
|
||||||
record_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> Response:
|
|
||||||
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
|
|
||||||
record = await _get_record_for_user(record_id, current_user.id, db)
|
|
||||||
blob = await _blob_store.download(current_user.id, record.s3_key)
|
|
||||||
return Response(
|
|
||||||
content=blob,
|
|
||||||
media_type="application/octet-stream",
|
|
||||||
headers={"X-Checksum": record.checksum},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/records/{record_id}", response_model=dict)
|
|
||||||
async def update_record(
|
|
||||||
record_id: str,
|
|
||||||
body: StorageRecordUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Replace the blob for an existing record. Verifies checksum before storing."""
|
|
||||||
record = await _get_record_for_user(record_id, current_user.id, db)
|
|
||||||
reject_if_tampered(body.blob, body.checksum)
|
|
||||||
|
|
||||||
delta = len(body.blob) - record.size_bytes
|
|
||||||
if delta > 0:
|
|
||||||
await _check_quota(current_user, delta, db)
|
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
|
||||||
current_user.id, record.table_name, record_id, body.blob, body.checksum
|
|
||||||
)
|
|
||||||
|
|
||||||
record.s3_key = s3_key
|
|
||||||
record.checksum = body.checksum
|
|
||||||
record.size_bytes = len(body.blob)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/records/{record_id}", response_model=dict)
|
|
||||||
async def delete_record(
|
|
||||||
record_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a record and its S3 blob."""
|
|
||||||
record = await _get_record_for_user(record_id, current_user.id, db)
|
|
||||||
await _blob_store.delete(current_user.id, record.s3_key)
|
|
||||||
await db.delete(record)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.core.llm import embed
|
|
||||||
from app.schemas import (
|
|
||||||
UserProfile,
|
|
||||||
VectorSearchRequest,
|
|
||||||
VectorSearchResponse,
|
|
||||||
VectorUpsertRequest,
|
|
||||||
)
|
|
||||||
from app.storage.encryption import reject_if_tampered
|
|
||||||
from app.storage.vector_store import VectorStore
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/storage", tags=["vectors"])
|
|
||||||
|
|
||||||
_vector_store = VectorStore()
|
|
||||||
|
|
||||||
|
|
||||||
class _VectorDeleteRequest(BaseModel):
|
|
||||||
ids: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
class _EmbedRequest(BaseModel):
|
|
||||||
text: str
|
|
||||||
|
|
||||||
|
|
||||||
class _EmbedResponse(BaseModel):
|
|
||||||
vector: list[float]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/vectors/upsert", response_model=dict)
|
|
||||||
async def upsert_vectors(
|
|
||||||
body: VectorUpsertRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, int]:
|
|
||||||
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
|
|
||||||
for item in body.vectors:
|
|
||||||
reject_if_tampered(item.blob, item.checksum)
|
|
||||||
await _vector_store.upsert(current_user.id, body.vectors)
|
|
||||||
return {"upserted": len(body.vectors)}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/vectors/search", response_model=VectorSearchResponse)
|
|
||||||
async def search_vectors(
|
|
||||||
body: VectorSearchRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> VectorSearchResponse:
|
|
||||||
"""Search the user-scoped vector namespace with an encrypted query blob."""
|
|
||||||
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
|
|
||||||
return VectorSearchResponse(results=results)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/vectors", response_model=dict)
|
|
||||||
async def delete_vectors(
|
|
||||||
body: _VectorDeleteRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete vectors by ID, scoped to the authenticated user."""
|
|
||||||
await _vector_store.delete(current_user.id, body.ids)
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/vectors/embed", response_model=_EmbedResponse)
|
|
||||||
async def embed_text(
|
|
||||||
body: _EmbedRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> _EmbedResponse:
|
|
||||||
"""Generate a 1536-dim embedding vector for the given text.
|
|
||||||
|
|
||||||
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
|
||||||
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
|
|
||||||
"""
|
|
||||||
vector = await embed(body.text)
|
|
||||||
return _EmbedResponse(vector=vector)
|
|
||||||
1
app/auth/__init__.py
Normal file
1
app/auth/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"OAuth provider abstractions and utilities."
|
||||||
135
app/auth/oauth_providers.py
Normal file
135
app/auth/oauth_providers.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""OAuth 2.0 + PKCE provider abstractions.
|
||||||
|
|
||||||
|
Each provider implements a three-step flow designed for a desktop (public) client:
|
||||||
|
|
||||||
|
1. get_authorization_url(state, code_challenge) → str
|
||||||
|
Build the provider's consent-screen URL. State and code_challenge are
|
||||||
|
generated server-side; the client opens this URL in the system browser.
|
||||||
|
|
||||||
|
2. exchange_code(code, code_verifier, redirect_uri) → dict
|
||||||
|
Exchange the short-lived authorization code for an access token.
|
||||||
|
The code_verifier proves ownership of the PKCE challenge.
|
||||||
|
|
||||||
|
3. get_userinfo(access_token) → OAuthUserInfo
|
||||||
|
Fetch the canonical user identity from the provider.
|
||||||
|
|
||||||
|
Currently supported providers:
|
||||||
|
- GoogleOAuthProvider (scope: openid email profile)
|
||||||
|
|
||||||
|
Adding a new provider:
|
||||||
|
- Implement the three methods above.
|
||||||
|
- Register in _PROVIDERS inside routes/auth.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import urllib.parse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
# ── Data transfer objects ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthUserInfo:
|
||||||
|
"""Normalized user identity returned by any provider."""
|
||||||
|
|
||||||
|
provider_user_id: str
|
||||||
|
email: str
|
||||||
|
email_verified: bool
|
||||||
|
avatar_url: str | None
|
||||||
|
name: str | None
|
||||||
|
|
||||||
|
|
||||||
|
# ── PKCE helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pkce_pair() -> tuple[str, str]:
|
||||||
|
"""Generate a (code_verifier, code_challenge) pair for PKCE S256.
|
||||||
|
|
||||||
|
The code_verifier is a random 32-byte URL-safe base64 string.
|
||||||
|
The code_challenge is SHA-256(code_verifier) base64url-encoded (no padding).
|
||||||
|
"""
|
||||||
|
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode()
|
||||||
|
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||||
|
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||||
|
return code_verifier, code_challenge
|
||||||
|
|
||||||
|
|
||||||
|
# ── Google provider ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleOAuthProvider:
|
||||||
|
"""Google OAuth 2.0 provider (openid email profile scope).
|
||||||
|
|
||||||
|
Uses Google's standard authorization endpoint with PKCE S256.
|
||||||
|
Does NOT use google-auth-oauthlib to keep the flow generic and async.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "google"
|
||||||
|
|
||||||
|
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||||
|
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||||
|
_USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||||
|
|
||||||
|
def __init__(self, client_id: str, client_secret: str, redirect_uri: str) -> None:
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_secret = client_secret
|
||||||
|
self.redirect_uri = redirect_uri
|
||||||
|
|
||||||
|
def get_authorization_url(self, state: str, code_challenge: str) -> str:
|
||||||
|
"""Build the Google consent-screen URL."""
|
||||||
|
params = {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": "openid email profile",
|
||||||
|
"state": state,
|
||||||
|
"code_challenge": code_challenge,
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
"access_type": "offline",
|
||||||
|
"prompt": "select_account",
|
||||||
|
}
|
||||||
|
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||||
|
|
||||||
|
async def exchange_code(
|
||||||
|
self, code: str, code_verifier: str, redirect_uri: str
|
||||||
|
) -> dict:
|
||||||
|
"""Exchange authorization code for an access token."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
self._TOKEN_URL,
|
||||||
|
data={
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"code": code,
|
||||||
|
"code_verifier": code_verifier,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def get_userinfo(self, access_token: str) -> OAuthUserInfo:
|
||||||
|
"""Fetch the authenticated user's identity from Google."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
self._USERINFO_URL,
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
return OAuthUserInfo(
|
||||||
|
provider_user_id=data["sub"],
|
||||||
|
email=data["email"],
|
||||||
|
email_verified=data.get("email_verified", False),
|
||||||
|
avatar_url=data.get("picture"),
|
||||||
|
name=data.get("name"),
|
||||||
|
)
|
||||||
@@ -43,8 +43,8 @@ class StripeService:
|
|||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
tier: str,
|
tier: str,
|
||||||
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
success_url: str = "https://app.adiuvai.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||||
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
cancel_url: str = "https://app.adiuvai.app/billing/cancel",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a Stripe checkout session and return the URL.
|
"""Create a Stripe checkout session and return the URL.
|
||||||
|
|
||||||
@@ -200,6 +200,45 @@ class StripeService:
|
|||||||
sub.status = "canceled"
|
sub.status = "canceled"
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
async def list_invoices(
|
||||||
|
self, user_id: str, db: AsyncSession, limit: int = 24
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return recent invoices for the user from Stripe.
|
||||||
|
|
||||||
|
Returns an empty list when Stripe is not configured or the user has
|
||||||
|
no ``stripe_customer_id``.
|
||||||
|
"""
|
||||||
|
if not self._configured():
|
||||||
|
return []
|
||||||
|
|
||||||
|
from app.models import User # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(User.stripe_customer_id).where(User.id == user_id)
|
||||||
|
)
|
||||||
|
customer_id = result.scalar_one_or_none()
|
||||||
|
if not customer_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
s = self._client()
|
||||||
|
invoices = s.Invoice.list(customer=customer_id, limit=limit)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": inv.id,
|
||||||
|
"amount_due": inv.amount_due,
|
||||||
|
"amount_paid": inv.amount_paid,
|
||||||
|
"currency": inv.currency,
|
||||||
|
"status": inv.status,
|
||||||
|
"created": inv.created * 1000, # epoch ms
|
||||||
|
"invoice_url": inv.hosted_invoice_url,
|
||||||
|
"invoice_pdf": inv.invoice_pdf,
|
||||||
|
}
|
||||||
|
for inv in invoices.auto_paging_iter()
|
||||||
|
]
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
# ── Private DB helpers ───────────────────────────────────────────────
|
# ── Private DB helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
async def _upsert_subscription(
|
async def _upsert_subscription(
|
||||||
|
|||||||
@@ -21,42 +21,50 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"free": {
|
"free": {
|
||||||
"agents": 3,
|
"agents": 3,
|
||||||
"batch_active": 2,
|
"batch_active": 2,
|
||||||
"cloud_storage_gb": 0,
|
"batch_runs_per_day": 5,
|
||||||
"backup_gb": 0,
|
|
||||||
"providers": 1,
|
"providers": 1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
"plugin_marketplace": False,
|
|
||||||
"sso": False,
|
"sso": False,
|
||||||
|
"real_embeddings": False, # keyword fallback only
|
||||||
|
"realtime_extraction": False, # batch queue (Phase 2)
|
||||||
|
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
||||||
|
"proactive_mining": False, # Power+ only (Phase 5)
|
||||||
},
|
},
|
||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
"batch_active": 10,
|
"batch_active": 10,
|
||||||
"cloud_storage_gb": 5,
|
"batch_runs_per_day": 50,
|
||||||
"backup_gb": 5,
|
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
"plugin_marketplace": False,
|
|
||||||
"sso": False,
|
"sso": False,
|
||||||
|
"real_embeddings": True, # pgvector cosine search
|
||||||
|
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||||
|
"relational_memory": True, # person/project predicates
|
||||||
|
"proactive_mining": False, # Power+ only (Phase 5)
|
||||||
},
|
},
|
||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1, # unlimited
|
"batch_active": -1, # unlimited
|
||||||
"cloud_storage_gb": 25,
|
"batch_runs_per_day": -1, # unlimited
|
||||||
"backup_gb": 25,
|
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
"plugin_marketplace": True,
|
|
||||||
"sso": False,
|
"sso": False,
|
||||||
|
"real_embeddings": True,
|
||||||
|
"realtime_extraction": True,
|
||||||
|
"relational_memory": True, # all predicates incl. custom
|
||||||
|
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||||
},
|
},
|
||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1,
|
"batch_active": -1,
|
||||||
"cloud_storage_gb": -1, # unlimited
|
"batch_runs_per_day": -1, # unlimited
|
||||||
"backup_gb": -1, # unlimited
|
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
"plugin_marketplace": True,
|
|
||||||
"sso": True,
|
"sso": True,
|
||||||
|
"real_embeddings": True,
|
||||||
|
"realtime_extraction": True,
|
||||||
|
"relational_memory": True, # all predicates incl. custom
|
||||||
|
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,16 +85,18 @@ class TierManager:
|
|||||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
"""Return the current billing tier for ``user_id`` from the DB.
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
Falls back to ``'free'`` when no subscription row exists.
|
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
|
||||||
|
when no subscription row exists.
|
||||||
"""
|
"""
|
||||||
from app.models import Subscription # noqa: PLC0415
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str | None = result.scalar_one_or_none()
|
tier: str | None = result.scalar_one_or_none()
|
||||||
if tier is None or tier not in FEATURES:
|
if tier is None or tier not in FEATURES:
|
||||||
return "free"
|
return "power" if settings.ENV == "dev" else "free"
|
||||||
return tier # type: ignore[return-value]
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
# ── Feature access ───────────────────────────────────────────────────
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
@@ -119,71 +129,6 @@ class TierManager:
|
|||||||
"""Return the requests-per-minute limit for ``tier``."""
|
"""Return the requests-per-minute limit for ``tier``."""
|
||||||
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||||
|
|
||||||
# ── Storage quota ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def enforce_quota(
|
|
||||||
self,
|
|
||||||
tier: BillingTier,
|
|
||||||
current_bytes: int = 0,
|
|
||||||
additional_bytes: int = 0,
|
|
||||||
) -> None:
|
|
||||||
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
|
||||||
|
|
||||||
``tier`` is the caller's current tier (from ``current_user.tier``).
|
|
||||||
``current_bytes`` is the total bytes already stored (queried by caller).
|
|
||||||
"""
|
|
||||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
|
||||||
if limit_gb == 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Cloud storage is not available on the '{tier}' tier",
|
|
||||||
)
|
|
||||||
if limit_gb == -1:
|
|
||||||
return # unlimited
|
|
||||||
limit_bytes = limit_gb * 1024 ** 3
|
|
||||||
if current_bytes + additional_bytes > limit_bytes:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Storage quota exceeded for tier '{tier}'",
|
|
||||||
)
|
|
||||||
|
|
||||||
def enforce_backup_quota(
|
|
||||||
self,
|
|
||||||
tier: BillingTier,
|
|
||||||
current_bytes: int = 0,
|
|
||||||
additional_bytes: int = 0,
|
|
||||||
) -> None:
|
|
||||||
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
|
||||||
limit_gb: int = FEATURES[tier]["backup_gb"]
|
|
||||||
if limit_gb == 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Backup is not available on the '{tier}' tier",
|
|
||||||
)
|
|
||||||
if limit_gb == -1:
|
|
||||||
return # unlimited
|
|
||||||
limit_bytes = limit_gb * 1024 ** 3
|
|
||||||
if current_bytes + additional_bytes > limit_bytes:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
|
||||||
)
|
|
||||||
|
|
||||||
def check_quota(
|
|
||||||
self,
|
|
||||||
tier: BillingTier,
|
|
||||||
current_bytes: int = 0,
|
|
||||||
additional_bytes: int = 0,
|
|
||||||
) -> bool:
|
|
||||||
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
|
||||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
|
||||||
if limit_gb == 0:
|
|
||||||
return False
|
|
||||||
if limit_gb == -1:
|
|
||||||
return True
|
|
||||||
limit_bytes = limit_gb * 1024 ** 3
|
|
||||||
return current_bytes + additional_bytes <= limit_bytes
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton shared across the app.
|
# Module-level singleton shared across the app.
|
||||||
tier_manager = TierManager()
|
tier_manager = TierManager()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
|
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai"
|
||||||
JWT_SECRET: str = "change-me-in-production"
|
JWT_SECRET: str = "change-me-in-production"
|
||||||
JWT_ALGORITHM: str = "HS256"
|
JWT_ALGORITHM: str = "HS256"
|
||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||||
@@ -12,26 +12,26 @@ class Settings(BaseSettings):
|
|||||||
STRIPE_SECRET_KEY: str = ""
|
STRIPE_SECRET_KEY: str = ""
|
||||||
STRIPE_WEBHOOK_SECRET: str = ""
|
STRIPE_WEBHOOK_SECRET: str = ""
|
||||||
|
|
||||||
S3_BUCKET: str = ""
|
|
||||||
S3_REGION: str = "us-east-1"
|
|
||||||
S3_ENDPOINT_URL: str = ""
|
|
||||||
AWS_ACCESS_KEY_ID: str = ""
|
|
||||||
AWS_SECRET_ACCESS_KEY: str = ""
|
|
||||||
|
|
||||||
PINECONE_API_KEY: str = ""
|
|
||||||
PINECONE_INDEX: str = "adiuva"
|
|
||||||
QDRANT_URL: str = ""
|
|
||||||
QDRANT_API_KEY: str = ""
|
|
||||||
|
|
||||||
OPENAI_API_KEY: str = ""
|
OPENAI_API_KEY: str = ""
|
||||||
ANTHROPIC_API_KEY: str = ""
|
ANTHROPIC_API_KEY: str = ""
|
||||||
GOOGLE_API_KEY: str = ""
|
GOOGLE_API_KEY: str = ""
|
||||||
CEREBRAS_API_KEY: str = ""
|
CEREBRAS_API_KEY: str = ""
|
||||||
|
|
||||||
LLM_MODEL: str = "gpt-4o"
|
LLM_MODEL: str = "gpt-4o"
|
||||||
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
|
||||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
|
|
||||||
|
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
|
||||||
|
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing)
|
||||||
|
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
|
||||||
|
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
|
||||||
|
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||||
|
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||||
|
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
|
||||||
|
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||||
|
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
||||||
|
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
||||||
|
LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit)
|
||||||
|
|
||||||
# GitHub Copilot OAuth token storage directory.
|
# GitHub Copilot OAuth token storage directory.
|
||||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||||
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
||||||
@@ -45,16 +45,39 @@ class Settings(BaseSettings):
|
|||||||
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
||||||
MS_TENANT_ID: str = "common"
|
MS_TENANT_ID: str = "common"
|
||||||
|
|
||||||
|
# Google Login OAuth credentials — scope: openid email profile.
|
||||||
|
# Separate from GMAIL_CLIENT_ID/SECRET (which uses gmail.readonly scope).
|
||||||
|
GOOGLE_AUTH_CLIENT_ID: str = ""
|
||||||
|
GOOGLE_AUTH_CLIENT_SECRET: str = ""
|
||||||
|
# The redirect URI registered in Google Cloud Console.
|
||||||
|
# Google redirects here after consent; this backend route then bounces to
|
||||||
|
# the adiuvai:// deep link so the Electron app receives the code.
|
||||||
|
# Dev: http://localhost:8000/api/v1/auth/oauth/google/web-callback
|
||||||
|
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
|
||||||
|
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
|
||||||
|
|
||||||
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||||
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||||
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||||
OAUTH_ENCRYPTION_KEY: str = ""
|
OAUTH_ENCRYPTION_KEY: str = ""
|
||||||
|
|
||||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
CORS_ORIGINS: list[str] = [
|
||||||
|
"app://.",
|
||||||
|
"http://localhost:3000",
|
||||||
|
"http://localhost:5173",
|
||||||
|
"http://localhost:4173", # Vite preview (web SPA)
|
||||||
|
"https://app.adiuvai.com", # Production web portal
|
||||||
|
]
|
||||||
|
|
||||||
|
LANGFUSE_SECRET_KEY: str = ""
|
||||||
|
LANGFUSE_PUBLIC_KEY: str = ""
|
||||||
|
LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
|
||||||
|
|
||||||
|
SCHEDULER_ENABLED: bool = True
|
||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
"""Agent Registry — base classes and singleton registry for chat agents."""
|
"""Minimal agent base types retained for compatibility with batch runners."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(ABC):
|
class BaseAgent(ABC):
|
||||||
"""Common base for all agents."""
|
"""Common base for non-chat agents still using the old base contract."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -28,190 +27,4 @@ class BaseAgent(ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def skills(self) -> list[str]:
|
def skills(self) -> list[str]:
|
||||||
"""Override in subclasses to advertise capabilities."""
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
class ChatAgent(BaseAgent):
|
|
||||||
"""Base class for LLM-powered chat agents."""
|
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
# Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
|
|
||||||
self.tool_results: list[dict] = []
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
"""Process a user query and return a text response."""
|
|
||||||
...
|
|
||||||
|
|
||||||
async def handle_stream(
|
|
||||||
self, query: str, context: dict[str, Any]
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Streaming variant of handle().
|
|
||||||
|
|
||||||
Default: calls handle() and yields the full response as one chunk.
|
|
||||||
Override in subclasses for true token-level streaming via _tool_loop_stream.
|
|
||||||
"""
|
|
||||||
yield await self.handle(query, context)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
"""Return LangChain tool definitions available to this agent."""
|
|
||||||
...
|
|
||||||
|
|
||||||
async def _tool_loop(
|
|
||||||
self,
|
|
||||||
llm: Any,
|
|
||||||
messages: list[Any],
|
|
||||||
tools: list[Any],
|
|
||||||
max_iter: int = 5,
|
|
||||||
) -> str:
|
|
||||||
"""Shared tool-calling loop.
|
|
||||||
|
|
||||||
Binds *tools* to *llm*, invokes iteratively until the model stops
|
|
||||||
requesting tool calls or *max_iter* is reached, and returns the
|
|
||||||
final text response. Captures raw execute_on_client results in
|
|
||||||
``self.tool_results``.
|
|
||||||
"""
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
|
||||||
|
|
||||||
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
|
||||||
|
|
||||||
collector: list[dict] = []
|
|
||||||
set_tool_result_collector(collector)
|
|
||||||
try:
|
|
||||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
|
||||||
|
|
||||||
for _ in range(max_iter):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
return str(response.content)
|
|
||||||
|
|
||||||
# Execute each requested tool call
|
|
||||||
tool_map = {t.name: t for t in tools}
|
|
||||||
for call in response.tool_calls:
|
|
||||||
tool_fn = tool_map.get(call["name"])
|
|
||||||
if tool_fn is None:
|
|
||||||
result = f"Unknown tool: {call['name']}"
|
|
||||||
else:
|
|
||||||
result = await tool_fn.ainvoke(call["args"])
|
|
||||||
messages.append(
|
|
||||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Exhausted iterations — ask model for a final answer without tools
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
return str(response.content)
|
|
||||||
finally:
|
|
||||||
clear_tool_result_collector()
|
|
||||||
self.tool_results = collector
|
|
||||||
|
|
||||||
async def _tool_loop_stream(
|
|
||||||
self,
|
|
||||||
llm: Any,
|
|
||||||
messages: list[Any],
|
|
||||||
tools: list[Any],
|
|
||||||
max_iter: int = 5,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Streaming variant of ``_tool_loop``.
|
|
||||||
|
|
||||||
Behaves identically for tool-calling iterations (uses ainvoke to parse
|
|
||||||
tool calls). For the final response — when the model produces no further
|
|
||||||
tool calls — switches to ``llm.astream()`` and yields text tokens.
|
|
||||||
Captures raw execute_on_client results in ``self.tool_results``.
|
|
||||||
"""
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
|
||||||
|
|
||||||
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
|
||||||
|
|
||||||
collector: list[dict] = []
|
|
||||||
set_tool_result_collector(collector)
|
|
||||||
try:
|
|
||||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
|
||||||
|
|
||||||
for _ in range(max_iter):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
# Stream the final answer — don't keep the ainvoke result.
|
|
||||||
async for chunk in llm.astream(messages):
|
|
||||||
if chunk.content:
|
|
||||||
yield str(chunk.content)
|
|
||||||
return
|
|
||||||
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
# Execute each requested tool call
|
|
||||||
tool_map = {t.name: t for t in tools}
|
|
||||||
for call in response.tool_calls:
|
|
||||||
tool_fn = tool_map.get(call["name"])
|
|
||||||
if tool_fn is None:
|
|
||||||
result = f"Unknown tool: {call['name']}"
|
|
||||||
else:
|
|
||||||
result = await tool_fn.ainvoke(call["args"])
|
|
||||||
messages.append(
|
|
||||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Exhausted iterations — stream a final answer without tools
|
|
||||||
async for chunk in llm.astream(messages):
|
|
||||||
if chunk.content:
|
|
||||||
yield str(chunk.content)
|
|
||||||
finally:
|
|
||||||
clear_tool_result_collector()
|
|
||||||
self.tool_results = collector
|
|
||||||
|
|
||||||
|
|
||||||
class AgentRegistry:
|
|
||||||
"""Singleton registry for ChatAgent subclasses."""
|
|
||||||
|
|
||||||
_instance: AgentRegistry | None = None
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._agents: dict[str, type[ChatAgent]] = {}
|
|
||||||
|
|
||||||
def __new__(cls) -> AgentRegistry:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
cls._instance._agents = {}
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
# ── public API ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
|
|
||||||
"""Class decorator — registers an agent by its name."""
|
|
||||||
instance = agent_class()
|
|
||||||
name = instance.get_name()
|
|
||||||
self._agents[name] = agent_class
|
|
||||||
return agent_class
|
|
||||||
|
|
||||||
def get(self, name: str) -> ChatAgent:
|
|
||||||
"""Return a fresh instance of the named agent."""
|
|
||||||
cls = self._agents.get(name)
|
|
||||||
if cls is None:
|
|
||||||
raise KeyError(f"Agent not found: {name}")
|
|
||||||
return cls()
|
|
||||||
|
|
||||||
def list_agents(self) -> list[dict[str, str]]:
|
|
||||||
"""Return ``[{name, description}]`` for the orchestrator prompt."""
|
|
||||||
result: list[dict[str, str]] = []
|
|
||||||
for cls in self._agents.values():
|
|
||||||
inst = cls()
|
|
||||||
result.append(
|
|
||||||
{"name": inst.get_name(), "description": inst.get_description()}
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def call_agent(
|
|
||||||
self, name: str, query: str, context: dict[str, Any]
|
|
||||||
) -> str:
|
|
||||||
"""Instantiate the named agent and call its ``handle`` method."""
|
|
||||||
agent = self.get(name)
|
|
||||||
return await agent.handle(query, context)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
registry = AgentRegistry()
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
222
app/core/brief_agent.py
Normal file
222
app/core/brief_agent.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""Brief agent — produces plain-text home and project status briefs.
|
||||||
|
|
||||||
|
Read-only tool subset only. Never calls _normalize_tagged_list_lines —
|
||||||
|
the brief prompt forbids XML tags, so skipping post-processing is intentional.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from datetime import date
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.note_agent import NOTE_READ_TOOLS
|
||||||
|
from app.agents.project_agent import PROJECT_READ_TOOLS
|
||||||
|
from app.agents.task_agent import TASK_READ_TOOLS
|
||||||
|
from app.agents.timeline_agent import TIMELINE_READ_TOOLS
|
||||||
|
from app.core.deep_agent import (
|
||||||
|
_language_instruction,
|
||||||
|
_proactive_hints_injection,
|
||||||
|
_read_only_memory_tools,
|
||||||
|
_relational_memory_injection,
|
||||||
|
_run_single_agent_stream,
|
||||||
|
_trace_id_from_context,
|
||||||
|
)
|
||||||
|
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
|
||||||
|
|
||||||
|
_LANGUAGE_NAMES: dict[str, str] = {
|
||||||
|
"en": "English", "it": "Italian", "es": "Spanish",
|
||||||
|
"fr": "French", "de": "German",
|
||||||
|
"english": "English", "italian": "Italian", "italiano": "Italian",
|
||||||
|
"spanish": "Spanish", "español": "Spanish",
|
||||||
|
"french": "French", "français": "French",
|
||||||
|
"german": "German", "deutsch": "German",
|
||||||
|
}
|
||||||
|
|
||||||
|
_HOME_BRIEF_FALLBACK = """\
|
||||||
|
You are the user's personal assistant producing a short daily brief.
|
||||||
|
|
||||||
|
ROLE
|
||||||
|
Act like a calm, attentive secretary writing a stand-up note for your boss.
|
||||||
|
Warm and human, never breezy. Never cheerful filler, never emojis, never
|
||||||
|
"here is your brief" meta-text. The user is opening the app mid-workday and
|
||||||
|
is probably stressed — your job is to lower cognitive load, not add noise.
|
||||||
|
|
||||||
|
TOOLS — always call before writing
|
||||||
|
Pull fresh data every run. Do not invent counts or titles. Use at minimum:
|
||||||
|
- list_tasks_due_today — tasks the user owes today
|
||||||
|
- list_timelines_today — events starting or ending today
|
||||||
|
- list_all_projects — projects currently in progress or at risk
|
||||||
|
- memory_list_blocks / memory_get — personal context about people, clients,
|
||||||
|
payment habits, working preferences
|
||||||
|
If a tool returns nothing, simply omit that topic. Never report zeros.
|
||||||
|
|
||||||
|
WHAT TO INCLUDE
|
||||||
|
1. Tasks due today (title + priority; group the 1-2 most important).
|
||||||
|
2. Timeline events starting or ending today (and anything that starts/ends
|
||||||
|
tomorrow if the user has a very light day).
|
||||||
|
3. Active projects that need a nudge — stalled, blocked, or awaiting input.
|
||||||
|
4. Memory-aware colour where it sharpens the brief. Examples:
|
||||||
|
- "Client Rossi tends to pay late — the Acme invoice is 6 days out."
|
||||||
|
- "You usually dislike meetings before 10:00 — the call at 09:30 is unusual."
|
||||||
|
Only add a memory line when it changes what the user does. Do not pad.
|
||||||
|
|
||||||
|
WHAT TO OMIT
|
||||||
|
- Zero-counts ("no overdue items", "0 meetings today").
|
||||||
|
- Statistics ("2 active projects, 3 completed tasks").
|
||||||
|
- Headers, titles, greetings, sign-offs, dates, emojis, slang.
|
||||||
|
- Meta-phrases ("here is", "let me know if", "hope this helps").
|
||||||
|
- XML/HTML tags of any kind. Plain prose only.
|
||||||
|
|
||||||
|
LIGHT-DAY CLAUSE
|
||||||
|
If tasks + events + active-project-nudges together produce fewer than two
|
||||||
|
sentences of content, also list 1-2 projects in status on_hold or waiting
|
||||||
|
and ask a single, specific question about them — e.g. "Is the Bianchi
|
||||||
|
redesign still paused, or ready to pick back up?" One question max, grounded
|
||||||
|
in a real project name.
|
||||||
|
|
||||||
|
VOICE
|
||||||
|
- Calm. Concise. Human. Short sentences.
|
||||||
|
- Use **bold** sparingly for task titles, project names, and people's names.
|
||||||
|
- No bullet lists. Flow as 2-4 sentences of prose.
|
||||||
|
|
||||||
|
LENGTH
|
||||||
|
2-4 sentences total. Hard cap 4. If the day is truly empty, one sentence.
|
||||||
|
|
||||||
|
Respond in the user's language ({language}). Today is {today}.\
|
||||||
|
"""
|
||||||
|
|
||||||
|
_PROJECT_BRIEF_FALLBACK = """\
|
||||||
|
You are the project assistant producing a short status brief for ONE project.
|
||||||
|
|
||||||
|
ROLE
|
||||||
|
A senior project manager summarising state-of-play for the owner. Factual,
|
||||||
|
sharp, forward-looking. Never reassuring filler, never emojis.
|
||||||
|
|
||||||
|
SCOPE
|
||||||
|
Work only with project_id = {project_id}. Do not mention or pull data from
|
||||||
|
other projects. Use tools to fetch fresh data:
|
||||||
|
- get_project — current status, dates, description
|
||||||
|
- list_tasks(project_id) — open work, split by status
|
||||||
|
- list_timelines(project_id) — milestones hit, upcoming, overdue
|
||||||
|
- list_notes(project_id) — any recent decisions or blockers
|
||||||
|
- memory_get — relevant context about the client, collaborators, constraints
|
||||||
|
|
||||||
|
STRUCTURE — follow exactly, one short paragraph per section, no headers
|
||||||
|
1. **State.** One sentence: current phase, health (on track / at risk / blocked),
|
||||||
|
and why. Cite the concrete signal (overdue milestone, stalled tasks, recent
|
||||||
|
blocker note).
|
||||||
|
2. **What's moving.** What was completed or progressed recently. Name specific
|
||||||
|
tasks or milestones.
|
||||||
|
3. **Next steps.** The 1-3 most important things the user should do next, in
|
||||||
|
priority order. Be concrete — task name, who owns it, when due if known.
|
||||||
|
If waiting on someone else, name them and what the ask is.
|
||||||
|
4. **Risks / memory-flagged items.** One line max. Only include when there is
|
||||||
|
a real risk or a relevant memory (e.g. late-paying client, tight deadline,
|
||||||
|
scope change). Omit the section entirely if nothing to say.
|
||||||
|
|
||||||
|
WHAT TO OMIT
|
||||||
|
- Zero-counts ("no overdue tasks").
|
||||||
|
- Generic advice ("keep up the good work").
|
||||||
|
- Greetings, headers, bullet lists, emojis, sign-offs, meta-phrases.
|
||||||
|
- XML/HTML tags or bracketed id lists. Plain prose only.
|
||||||
|
|
||||||
|
VOICE
|
||||||
|
- Direct. Factual. No fluff.
|
||||||
|
- Use **bold** sparingly for task titles, milestone names, and the owner's name.
|
||||||
|
- Short sentences. Prefer verbs over nouns ("Client review is blocking release"
|
||||||
|
not "There is a blocker which is the client review").
|
||||||
|
|
||||||
|
LENGTH
|
||||||
|
4-8 sentences total across the 3-4 sections. Hard cap 8.
|
||||||
|
|
||||||
|
Respond in the user's language ({language}). Today is {today}.\
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_language(context: dict[str, Any]) -> str:
|
||||||
|
core = context.get("core_memory") or {}
|
||||||
|
raw = (core.get("language") or "en").strip().lower()
|
||||||
|
return _LANGUAGE_NAMES.get(raw, raw.title()) or "English"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_read_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
return [
|
||||||
|
*TASK_READ_TOOLS,
|
||||||
|
*PROJECT_READ_TOOLS,
|
||||||
|
*TIMELINE_READ_TOOLS,
|
||||||
|
*NOTE_READ_TOOLS,
|
||||||
|
*_read_only_memory_tools(user_id, trace_id),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home_brief(
|
||||||
|
user_id: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Stream a plain-text daily home brief.
|
||||||
|
|
||||||
|
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
||||||
|
Do NOT post-process output through _normalize_tagged_list_lines.
|
||||||
|
"""
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
|
today = date.today().isoformat()
|
||||||
|
language = _resolve_language(context)
|
||||||
|
|
||||||
|
raw_template, langfuse_prompt = get_prompt_or_fallback("home_brief", _HOME_BRIEF_FALLBACK)
|
||||||
|
system_prompt = compile_prompt(raw_template, langfuse_prompt, language=language, today=today)
|
||||||
|
system_prompt += _relational_memory_injection(context)
|
||||||
|
system_prompt += _proactive_hints_injection(context)
|
||||||
|
system_prompt += _language_instruction(context)
|
||||||
|
if today not in system_prompt:
|
||||||
|
system_prompt += f"\nToday is {today}."
|
||||||
|
|
||||||
|
tools = _build_read_tools(user_id, trace_id)
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message="Generate the daily brief.",
|
||||||
|
context=context,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
|
agent_name="brief-agent",
|
||||||
|
tools=tools,
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
|
async def run_project_brief(
|
||||||
|
user_id: str,
|
||||||
|
project_id: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Stream a plain-text project status brief for project_id.
|
||||||
|
|
||||||
|
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
||||||
|
Do NOT post-process output through _normalize_tagged_list_lines.
|
||||||
|
"""
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
|
today = date.today().isoformat()
|
||||||
|
language = _resolve_language(context)
|
||||||
|
|
||||||
|
raw_template, langfuse_prompt = get_prompt_or_fallback("project_brief", _PROJECT_BRIEF_FALLBACK)
|
||||||
|
system_prompt = compile_prompt(
|
||||||
|
raw_template, langfuse_prompt,
|
||||||
|
language=language, today=today, project_id=project_id,
|
||||||
|
)
|
||||||
|
system_prompt += _relational_memory_injection(context)
|
||||||
|
system_prompt += _proactive_hints_injection(context)
|
||||||
|
system_prompt += _language_instruction(context)
|
||||||
|
if today not in system_prompt:
|
||||||
|
system_prompt += f"\nToday is {today}."
|
||||||
|
|
||||||
|
tools = _build_read_tools(user_id, trace_id)
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message=f"Generate the project status brief for project {project_id}.",
|
||||||
|
context=context,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
|
agent_name="brief-agent",
|
||||||
|
tools=tools,
|
||||||
|
):
|
||||||
|
yield event
|
||||||
1068
app/core/deep_agent.py
Normal file
1068
app/core/deep_agent.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,20 +3,15 @@
|
|||||||
Maintains in-memory state for all active Electron → backend WebSocket
|
Maintains in-memory state for all active Electron → backend WebSocket
|
||||||
connections. One connection per user (latest replaces previous).
|
connections. One connection per user (latest replaces previous).
|
||||||
|
|
||||||
The manager participates in two interaction patterns:
|
The manager handles the **tool-call round-trip** pattern:
|
||||||
|
- Backend sends ``tool_call`` frame → Electron executes the action →
|
||||||
|
returns ``tool_result`` frame.
|
||||||
|
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||||
|
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||||
|
receive the result dict from Electron.
|
||||||
|
|
||||||
1. **Tool-call round-trip** (bidirectional CRUD):
|
This pattern is used by all tools (CRUD, file-system, etc.) via
|
||||||
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
``execute_on_client()`` in ``ws_context.py``.
|
||||||
``tool_result`` frame.
|
|
||||||
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
|
||||||
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
|
||||||
receive the result dict from Electron.
|
|
||||||
|
|
||||||
2. **Agent-data streaming** (local directory agent runs):
|
|
||||||
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
|
||||||
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
|
||||||
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
|
||||||
a specific ``run_id`` so the agent runner can iterate frames.
|
|
||||||
|
|
||||||
The ``device_manager`` module-level singleton is imported by both the
|
The ``device_manager`` module-level singleton is imported by both the
|
||||||
device WS route and the agent runner.
|
device WS route and the agent runner.
|
||||||
@@ -42,8 +37,6 @@ class DeviceConnection:
|
|||||||
device_id: str
|
device_id: str
|
||||||
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
# Per-run queues for agent_data / agent_complete frames.
|
|
||||||
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceConnectionManager:
|
class DeviceConnectionManager:
|
||||||
@@ -153,31 +146,6 @@ class DeviceConnectionManager:
|
|||||||
if fut is not None and not fut.done():
|
if fut is not None and not fut.done():
|
||||||
fut.set_result(result)
|
fut.set_result(result)
|
||||||
|
|
||||||
# ── Agent-data queue ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
def get_agent_data_queue(
|
|
||||||
self, user_id: str, run_id: str
|
|
||||||
) -> asyncio.Queue[dict | None]:
|
|
||||||
"""Return (creating if absent) the queue for *run_id* agent frames.
|
|
||||||
|
|
||||||
The agent runner reads from this queue. The device WS handler writes
|
|
||||||
to it. ``None`` is the sentinel that signals the stream is finished.
|
|
||||||
"""
|
|
||||||
conn = self._connections.get(user_id)
|
|
||||||
if conn is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"get_agent_data_queue: user {user_id!r} is not connected"
|
|
||||||
)
|
|
||||||
if run_id not in conn.agent_data_queues:
|
|
||||||
conn.agent_data_queues[run_id] = asyncio.Queue()
|
|
||||||
return conn.agent_data_queues[run_id]
|
|
||||||
|
|
||||||
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
|
|
||||||
"""Remove the queue for *run_id* once a run has completed."""
|
|
||||||
conn = self._connections.get(user_id)
|
|
||||||
if conn:
|
|
||||||
conn.agent_data_queues.pop(run_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — import this everywhere.
|
# Module-level singleton — import this everywhere.
|
||||||
device_manager = DeviceConnectionManager()
|
device_manager = DeviceConnectionManager()
|
||||||
|
|||||||
34
app/core/embeddings.py
Normal file
34
app/core/embeddings.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""OpenAI embedding helper for associative memory tier.
|
||||||
|
|
||||||
|
Single public function: ``embed_text(text) -> list[float] | None``.
|
||||||
|
Returns None on any failure — callers must implement a keyword fallback.
|
||||||
|
Never raises; all exceptions are logged as warnings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_MAX_INPUT_CHARS = 8000
|
||||||
|
_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||||
|
|
||||||
|
|
||||||
|
async def embed_text(text: str) -> list[float] | None:
|
||||||
|
"""Call OpenAI text-embedding-3-small. Return None on failure (caller falls back to keyword)."""
|
||||||
|
try:
|
||||||
|
client = AsyncOpenAI()
|
||||||
|
truncated = text[:_MAX_INPUT_CHARS]
|
||||||
|
response = await client.embeddings.create(
|
||||||
|
input=truncated,
|
||||||
|
model=_EMBEDDING_MODEL,
|
||||||
|
)
|
||||||
|
result: list[float] = response.data[0].embedding
|
||||||
|
logger.debug("embeddings: embed_text dims=%d", len(result))
|
||||||
|
return result
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("embeddings: embed_text failed: %s", exc)
|
||||||
|
return None
|
||||||
@@ -1,222 +0,0 @@
|
|||||||
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.schemas import ExecutionPlan, PlanStep
|
|
||||||
|
|
||||||
|
|
||||||
# ── Prompt Template Registry ──────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplateRegistry:
|
|
||||||
"""Server-side store mapping template IDs to prompt text.
|
|
||||||
|
|
||||||
Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
|
|
||||||
The actual prompt text is resolved here on the server, keeping prompt IP
|
|
||||||
out of API responses.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._templates: dict[str, str] = {}
|
|
||||||
|
|
||||||
def register(self, template_id: str, prompt_text: str) -> None:
|
|
||||||
self._templates[template_id] = prompt_text
|
|
||||||
|
|
||||||
def get(self, template_id: str) -> str:
|
|
||||||
"""Resolve a template ID to its prompt text.
|
|
||||||
|
|
||||||
Raises ``KeyError`` if the template is not registered.
|
|
||||||
"""
|
|
||||||
text = self._templates.get(template_id)
|
|
||||||
if text is None:
|
|
||||||
raise KeyError(f"Template not found: {template_id!r}")
|
|
||||||
return text
|
|
||||||
|
|
||||||
def has(self, template_id: str) -> bool:
|
|
||||||
return template_id in self._templates
|
|
||||||
|
|
||||||
def list_ids(self) -> list[str]:
|
|
||||||
"""Return all registered template IDs (never the text)."""
|
|
||||||
return list(self._templates.keys())
|
|
||||||
|
|
||||||
|
|
||||||
# ── Execution Plan Builder ────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlanBuilder:
|
|
||||||
"""Fluent builder for ``ExecutionPlan`` objects.
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_llm_step("tpl_task_agent_default", {"message": user_msg})
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, agent: str) -> None:
|
|
||||||
self._agent = agent
|
|
||||||
self._steps: list[PlanStep] = []
|
|
||||||
|
|
||||||
# ── step adders ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def add_step(
|
|
||||||
self, action: str, params: dict[str, Any] | None = None
|
|
||||||
) -> ExecutionPlanBuilder:
|
|
||||||
"""Append a generic action step with optional parameters."""
|
|
||||||
self._steps.append(PlanStep(action=action, variables=params))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_llm_step(
|
|
||||||
self, template_id: str, variables: dict[str, Any] | None = None
|
|
||||||
) -> ExecutionPlanBuilder:
|
|
||||||
"""Append an LLM step referencing a server-side template by ID."""
|
|
||||||
self._steps.append(
|
|
||||||
PlanStep(action="llm", prompt_template=template_id, variables=variables)
|
|
||||||
)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
|
|
||||||
"""Append a step whose input comes from the output of an earlier step."""
|
|
||||||
self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
|
|
||||||
return self
|
|
||||||
|
|
||||||
# ── build ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def build(self) -> ExecutionPlan:
|
|
||||||
"""Validate step references and return the ``ExecutionPlan``.
|
|
||||||
|
|
||||||
Raises ``ValueError`` if any ``data_from_step`` references a
|
|
||||||
non-existent or future step index.
|
|
||||||
"""
|
|
||||||
for i, step in enumerate(self._steps):
|
|
||||||
if step.data_from_step is not None:
|
|
||||||
if not (0 <= step.data_from_step < i):
|
|
||||||
raise ValueError(
|
|
||||||
f"Step {i}: data_from_step={step.data_from_step} must "
|
|
||||||
f"reference a preceding step index in range 0..{i - 1}"
|
|
||||||
)
|
|
||||||
return ExecutionPlan(agent=self._agent, steps=list(self._steps))
|
|
||||||
|
|
||||||
|
|
||||||
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PlanCache:
|
|
||||||
"""In-memory LRU cache for ``ExecutionPlan`` objects.
|
|
||||||
|
|
||||||
Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
|
|
||||||
The cache also serves as a runtime memoisation layer so that repeated
|
|
||||||
identical intent classifications can skip re-building the plan.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, maxsize: int = 1000) -> None:
|
|
||||||
self._maxsize = maxsize
|
|
||||||
self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
|
|
||||||
|
|
||||||
def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
|
|
||||||
"""Store *plan* under *key*, evicting the LRU entry if at capacity."""
|
|
||||||
if key in self._cache:
|
|
||||||
del self._cache[key] # remove so re-insertion places it at the end
|
|
||||||
elif len(self._cache) >= self._maxsize:
|
|
||||||
self._cache.popitem(last=False) # evict least-recently-used
|
|
||||||
self._cache[key] = plan
|
|
||||||
|
|
||||||
def get_plan(self, key: str) -> ExecutionPlan | None:
|
|
||||||
"""Return the cached plan for *key*, or ``None`` if not present.
|
|
||||||
|
|
||||||
Accessing a plan marks it as most-recently used.
|
|
||||||
"""
|
|
||||||
if key not in self._cache:
|
|
||||||
return None
|
|
||||||
self._cache.move_to_end(key)
|
|
||||||
return self._cache[key]
|
|
||||||
|
|
||||||
def get_all_playbooks(self) -> list[ExecutionPlan]:
|
|
||||||
"""Return all cached plans (most-recently used last)."""
|
|
||||||
return list(self._cache.values())
|
|
||||||
|
|
||||||
|
|
||||||
# ── Module-level singletons ───────────────────────────────────────────
|
|
||||||
|
|
||||||
template_registry = PromptTemplateRegistry()
|
|
||||||
plan_cache = PlanCache()
|
|
||||||
|
|
||||||
|
|
||||||
def _register_builtin_templates() -> None:
|
|
||||||
"""Register the built-in server-side prompt templates.
|
|
||||||
|
|
||||||
These strings never leave the server. Clients only receive the IDs.
|
|
||||||
"""
|
|
||||||
_tpls: dict[str, str] = {
|
|
||||||
"tpl_task_agent_default": (
|
|
||||||
"You are a task management assistant. Help the user create, update, "
|
|
||||||
"list, and track tasks. Use correct status values (todo, in_progress, "
|
|
||||||
"done) and priority values (high, medium, low) from the workspace model."
|
|
||||||
),
|
|
||||||
"tpl_timeline_agent_default": (
|
|
||||||
"You are a project timeline assistant. Help the user create and manage "
|
|
||||||
"milestone timelines on their projects. Every timeline requires a "
|
|
||||||
"project_id and a date expressed as a Unix timestamp in milliseconds."
|
|
||||||
),
|
|
||||||
"tpl_project_agent_default": (
|
|
||||||
"You are a project management assistant. Help the user create, find, "
|
|
||||||
"update, and archive projects. Projects have a name, an optional client, "
|
|
||||||
"and a status of either active or archived."
|
|
||||||
),
|
|
||||||
"tpl_note_agent_default": (
|
|
||||||
"You are a note-taking assistant. Help the user create, retrieve, update, "
|
|
||||||
"and delete Markdown notes. Notes can optionally be linked to a project."
|
|
||||||
),
|
|
||||||
"tpl_task_extract_from_project": (
|
|
||||||
"Extract all actionable tasks from the provided project context. "
|
|
||||||
"Return a structured list of tasks, each with a title, inferred priority "
|
|
||||||
"(high, medium, or low), suggested status (todo), and a due_date in "
|
|
||||||
"milliseconds where a deadline can be inferred."
|
|
||||||
),
|
|
||||||
"tpl_note_weekly_summary": (
|
|
||||||
"Generate a weekly project summary note from the provided workspace data. "
|
|
||||||
"Include: tasks completed this week, tasks due soon, active projects, "
|
|
||||||
"and upcoming timelines. Format the output as clean Markdown."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
for tid, text in _tpls.items():
|
|
||||||
template_registry.register(tid, text)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_playbooks() -> None:
|
|
||||||
"""Pre-build and cache the built-in playbooks."""
|
|
||||||
playbooks: list[tuple[str, ExecutionPlan]] = [
|
|
||||||
(
|
|
||||||
"create_tasks_from_project",
|
|
||||||
ExecutionPlanBuilder("project_agent")
|
|
||||||
.add_llm_step(
|
|
||||||
"tpl_task_extract_from_project",
|
|
||||||
{"source": "project_context"},
|
|
||||||
)
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build(),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"generate_weekly_note",
|
|
||||||
ExecutionPlanBuilder("note_agent")
|
|
||||||
.add_llm_step(
|
|
||||||
"tpl_note_weekly_summary",
|
|
||||||
{"period": "last_7_days"},
|
|
||||||
)
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build(),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
for key, plan in playbooks:
|
|
||||||
plan_cache.cache_plan(key, plan)
|
|
||||||
|
|
||||||
|
|
||||||
# Initialise on module load
|
|
||||||
_register_builtin_templates()
|
|
||||||
_load_playbooks()
|
|
||||||
190
app/core/langfuse_client.py
Normal file
190
app/core/langfuse_client.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
"""Langfuse observability — singleton client and prompt helpers.
|
||||||
|
|
||||||
|
If LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY are not set,
|
||||||
|
all helpers are no-ops so the app works without Langfuse configured.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
Tracing::
|
||||||
|
|
||||||
|
from app.core.langfuse_client import get_langfuse
|
||||||
|
|
||||||
|
lf = get_langfuse()
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(as_type="span", name="my-agent") as span:
|
||||||
|
span.update(input=user_message)
|
||||||
|
# ... do work ...
|
||||||
|
span.update(output=result)
|
||||||
|
lf.flush()
|
||||||
|
|
||||||
|
Prompt management::
|
||||||
|
|
||||||
|
from app.core.langfuse_client import get_prompt_or_fallback
|
||||||
|
|
||||||
|
text, prompt_obj = get_prompt_or_fallback("home_system", FALLBACK_PROMPT)
|
||||||
|
# Use text as the system prompt; pass prompt_obj to generations for linking.
|
||||||
|
|
||||||
|
Linking a prompt to a generation::
|
||||||
|
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="llm-call",
|
||||||
|
model="gpt-4o",
|
||||||
|
prompt=prompt_obj, # links generation → prompt version in the UI
|
||||||
|
input=messages,
|
||||||
|
) as gen:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
gen.update(output=response.content, usage=_usage(response))
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any, Generator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_client: Any = None
|
||||||
|
_initialized: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_langfuse() -> Any | None:
|
||||||
|
"""Return the Langfuse singleton, or ``None`` when not configured."""
|
||||||
|
global _client, _initialized
|
||||||
|
if _initialized:
|
||||||
|
return _client
|
||||||
|
_initialized = True
|
||||||
|
|
||||||
|
from app.config.settings import settings # local import to avoid circular deps
|
||||||
|
|
||||||
|
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
|
||||||
|
logger.debug("langfuse: not configured — observability disabled")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import Langfuse
|
||||||
|
|
||||||
|
_client = Langfuse(
|
||||||
|
secret_key=settings.LANGFUSE_SECRET_KEY,
|
||||||
|
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
||||||
|
host=settings.LANGFUSE_BASE_URL,
|
||||||
|
)
|
||||||
|
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_BASE_URL)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse: failed to initialize: %s", exc)
|
||||||
|
_client = None
|
||||||
|
|
||||||
|
return _client
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
|
||||||
|
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
|
||||||
|
|
||||||
|
Returns ``(raw_template, prompt_obj_or_None)``.
|
||||||
|
|
||||||
|
* ``raw_template`` — the uncompiled template string. Do NOT call ``.format()``
|
||||||
|
on it directly; use :func:`compile_prompt` instead so the correct variable
|
||||||
|
syntax is applied (``{{var}}`` for Langfuse, ``{var}`` for the fallback).
|
||||||
|
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
|
||||||
|
unavailable / the fetch failed. Pass this to generation observations so
|
||||||
|
Langfuse links the generation to the exact prompt version in the UI.
|
||||||
|
"""
|
||||||
|
lf = get_langfuse()
|
||||||
|
if lf is None:
|
||||||
|
return fallback, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
prompt = lf.get_prompt(name, label="production", fallback=fallback)
|
||||||
|
# For text-type prompts .prompt holds the raw template string.
|
||||||
|
raw = prompt.prompt if hasattr(prompt, "prompt") and isinstance(prompt.prompt, str) else fallback
|
||||||
|
return raw, prompt
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse: get_prompt %r failed: %s — using fallback", name, exc)
|
||||||
|
return fallback, None
|
||||||
|
|
||||||
|
|
||||||
|
def compile_prompt(template: str, prompt_obj: Any, **variables: Any) -> str:
|
||||||
|
"""Compile *template* with *variables*, choosing the right syntax.
|
||||||
|
|
||||||
|
* When *prompt_obj* is a real Langfuse prompt object, calls
|
||||||
|
``prompt_obj.compile(**variables)`` which handles ``{{variable}}``
|
||||||
|
substitution as defined in the Langfuse UI.
|
||||||
|
* When *prompt_obj* is ``None`` (Langfuse unavailable or fetch failed),
|
||||||
|
falls back to ``template.format(**variables)`` which handles the
|
||||||
|
``{variable}`` syntax used in the hardcoded fallback strings.
|
||||||
|
|
||||||
|
This keeps callers oblivious to which syntax is in use.
|
||||||
|
"""
|
||||||
|
if prompt_obj is not None:
|
||||||
|
try:
|
||||||
|
compiled = prompt_obj.compile(**variables)
|
||||||
|
# compile() returns a string for text prompts.
|
||||||
|
if isinstance(compiled, str):
|
||||||
|
return compiled
|
||||||
|
# Chat prompts return a list of dicts — join text parts.
|
||||||
|
if isinstance(compiled, list):
|
||||||
|
return "\n".join(
|
||||||
|
m.get("content", "") for m in compiled if isinstance(m, dict)
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"langfuse: compile failed for prompt %r: %s — falling back to .format()",
|
||||||
|
getattr(prompt_obj, "name", "?"),
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return template.format(**variables)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_usage(response: Any) -> dict[str, int]:
|
||||||
|
"""Extract token usage from a LangChain AI message into Langfuse format."""
|
||||||
|
meta = getattr(response, "usage_metadata", None)
|
||||||
|
if not meta:
|
||||||
|
return {}
|
||||||
|
return {
|
||||||
|
"input": int(meta.get("input_tokens", 0)),
|
||||||
|
"output": int(meta.get("output_tokens", 0)),
|
||||||
|
"total": int(meta.get("total_tokens", 0)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def hash_user_id(user_id: str) -> str:
|
||||||
|
"""Return a SHA-256 hash of *user_id* for use as Langfuse ``user_id``.
|
||||||
|
|
||||||
|
This avoids sending raw database UUIDs to external observability services
|
||||||
|
while still providing a stable, deterministic identifier for per-user
|
||||||
|
metrics in the Langfuse dashboard.
|
||||||
|
"""
|
||||||
|
return hashlib.sha256(user_id.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def langfuse_context(
|
||||||
|
user_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> Generator[None, None, None]:
|
||||||
|
"""Propagate ``user_id`` (hashed) and ``session_id`` to all Langfuse observations.
|
||||||
|
|
||||||
|
No-op when Langfuse is not configured or parameters are empty.
|
||||||
|
"""
|
||||||
|
lf = get_langfuse()
|
||||||
|
if lf is None or (not user_id and not session_id):
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import propagate_attributes
|
||||||
|
except ImportError:
|
||||||
|
logger.debug("langfuse: propagate_attributes not available — skipping context")
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
attrs: dict[str, str] = {}
|
||||||
|
if user_id:
|
||||||
|
attrs["user_id"] = hash_user_id(user_id)
|
||||||
|
if session_id:
|
||||||
|
attrs["session_id"] = session_id
|
||||||
|
|
||||||
|
with propagate_attributes(**attrs):
|
||||||
|
yield
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
|
Every agent and the orchestrator call ``get_llm()``
|
||||||
instead of directly constructing a provider-specific class. The model string
|
instead of directly constructing a provider-specific class. The model string
|
||||||
follows the `LiteLLM model naming convention
|
follows the `LiteLLM model naming convention
|
||||||
<https://docs.litellm.ai/docs/providers>`_:
|
<https://docs.litellm.ai/docs/providers>`_:
|
||||||
@@ -11,13 +11,15 @@ follows the `LiteLLM model naming convention
|
|||||||
* Ollama: ``ollama/llama3``
|
* Ollama: ``ollama/llama3``
|
||||||
* Bedrock: ``bedrock/anthropic.claude-v2``
|
* Bedrock: ``bedrock/anthropic.claude-v2``
|
||||||
|
|
||||||
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
Switch providers by changing **LLM_MODEL** in ``.env``
|
||||||
— no code changes required.
|
— no code changes required.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import litellm
|
import litellm
|
||||||
@@ -32,6 +34,14 @@ from app.config.settings import settings
|
|||||||
# Drop them silently instead of raising UnsupportedParamsError.
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
# Some provider responses include a plain dict in the `usage` field where a
|
||||||
|
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
@@ -86,12 +96,37 @@ def get_llm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_router_llm(
|
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||||
|
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
|
||||||
|
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
|
||||||
|
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
|
||||||
|
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||||
|
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||||
|
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
|
||||||
|
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||||
|
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||||
|
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
||||||
|
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def model_for_agent(agent_name: str) -> str:
|
||||||
|
"""Return the resolved model string for *agent_name* (for Langfuse tracking)."""
|
||||||
|
return _AGENT_MODEL_SETTINGS.get(agent_name, lambda: settings.LLM_MODEL)()
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_llm(
|
||||||
|
agent_name: str,
|
||||||
*,
|
*,
|
||||||
temperature: float = 0,
|
temperature: float = 0,
|
||||||
) -> ChatOpenAI | ChatLiteLLM:
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
"""Return the lighter model used for intent classification / routing."""
|
"""Return an LLM configured for *agent_name*, respecting per-agent overrides.
|
||||||
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
|
||||||
|
Falls back to ``settings.LLM_MODEL`` for unknown agent names or when the
|
||||||
|
per-agent override is left empty in ``.env``.
|
||||||
|
"""
|
||||||
|
model = model_for_agent(agent_name)
|
||||||
|
return get_llm(model=model, temperature=temperature)
|
||||||
|
|
||||||
|
|
||||||
async def embed(text: str) -> list[float]:
|
async def embed(text: str) -> list[float]:
|
||||||
|
|||||||
450
app/core/memory_extraction.py
Normal file
450
app/core/memory_extraction.py
Normal file
@@ -0,0 +1,450 @@
|
|||||||
|
"""Mem0-style Extract/Update pipeline — Phase 2.
|
||||||
|
|
||||||
|
Runs after every ``store_episode`` call to distil durable facts, preferences,
|
||||||
|
routines, and relations from the latest conversation turn.
|
||||||
|
|
||||||
|
Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)``
|
||||||
|
|
||||||
|
Design notes
|
||||||
|
------------
|
||||||
|
- Two gpt-4o-mini calls per turn: extract candidates, then decide action per candidate.
|
||||||
|
- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving).
|
||||||
|
- Zero-trust: never logs decrypted user content; relation subject/object labels are
|
||||||
|
treated as identifiers (safe to log per spec).
|
||||||
|
- Must not raise into the request path — caller wraps in asyncio.create_task().
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context
|
||||||
|
from app.core.llm import get_agent_llm, model_for_agent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Fallback prompts (used when Langfuse unavailable) ─────────────────────────
|
||||||
|
|
||||||
|
_EXTRACTION_FALLBACK = (
|
||||||
|
"You are a memory extractor for a personal AI secretary. Given the last conversation "
|
||||||
|
"turn, the user's core memory, and recent episode summaries, identify durable facts, "
|
||||||
|
"preferences, routines, and person/project relations worth remembering.\n\n"
|
||||||
|
"Output JSON matching this schema exactly:\n"
|
||||||
|
'{{"candidates": [{{"type": "<fact|preference|relation|routine>", '
|
||||||
|
'"content": "<short canonical statement>", '
|
||||||
|
'"target_tier": "<core|associative|relational|proactive>", '
|
||||||
|
'"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n'
|
||||||
|
"Rules:\n"
|
||||||
|
"- Skip small talk, greetings, one-off questions.\n"
|
||||||
|
"- Max 5 candidates per call.\n"
|
||||||
|
"- Only extract durable information (still true next week).\n"
|
||||||
|
"- For type=relation: subject/predicate/object required.\n"
|
||||||
|
"- Default confidence=0.7.\n\n"
|
||||||
|
"## Last turn\n{last_turn}\n\n"
|
||||||
|
"## Core memory (current)\n{core_memory}\n\n"
|
||||||
|
"## Recent episodes\n{recent_episodes}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_DECIDE_FALLBACK = (
|
||||||
|
"You are a memory update decision engine. Given a new memory candidate and a list of "
|
||||||
|
"existing memories from the same tier, decide what action to take.\n\n"
|
||||||
|
"Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n"
|
||||||
|
"- ADD: new information not in existing memories.\n"
|
||||||
|
"- UPDATE: contradicts or supersedes an existing memory.\n"
|
||||||
|
"- DELETE: states something is no longer true.\n"
|
||||||
|
"- NOOP: already captured accurately.\n\n"
|
||||||
|
"## New candidate\n{candidate}\n\n"
|
||||||
|
"## Existing memories (same tier, top neighbours)\n{existing_memories}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pydantic schemas ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class MemoryCandidate(BaseModel):
|
||||||
|
type: Literal["fact", "preference", "relation", "routine"]
|
||||||
|
content: str
|
||||||
|
target_tier: Literal["core", "associative", "relational", "proactive"]
|
||||||
|
subject: str | None = None
|
||||||
|
predicate: str | None = None
|
||||||
|
object: str | None = None
|
||||||
|
confidence: float = Field(default=0.7, ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionResult(BaseModel):
|
||||||
|
candidates: list[MemoryCandidate] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task 2.1 — Extract candidates ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def extract_candidates(
|
||||||
|
last_turn: str,
|
||||||
|
core_memory: dict[str, str],
|
||||||
|
recent_episodes: list[str],
|
||||||
|
) -> ExtractionResult:
|
||||||
|
"""Call gpt-4o-mini to extract memory candidates from the latest turn.
|
||||||
|
|
||||||
|
Returns an ExtractionResult (may be empty on failure — never raises).
|
||||||
|
"""
|
||||||
|
core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)"
|
||||||
|
episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)"
|
||||||
|
|
||||||
|
template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK)
|
||||||
|
|
||||||
|
# Compile with Langfuse variable syntax ({{var}}) or fallback {var}
|
||||||
|
if prompt_obj is not None:
|
||||||
|
try:
|
||||||
|
system_text = prompt_obj.compile(
|
||||||
|
last_turn=last_turn,
|
||||||
|
core_memory=core_str,
|
||||||
|
recent_episodes=episodes_str,
|
||||||
|
)
|
||||||
|
if isinstance(system_text, list):
|
||||||
|
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_extraction: compile failed: %s", exc)
|
||||||
|
system_text = template.format(
|
||||||
|
last_turn=last_turn,
|
||||||
|
core_memory=core_str,
|
||||||
|
recent_episodes=episodes_str,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
system_text = template.format(
|
||||||
|
last_turn=last_turn,
|
||||||
|
core_memory=core_str,
|
||||||
|
recent_episodes=episodes_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = get_agent_llm("memory-extractor", temperature=0)
|
||||||
|
# Bind JSON mode so the model always returns parseable output.
|
||||||
|
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
lf = get_langfuse()
|
||||||
|
try:
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=system_text),
|
||||||
|
HumanMessage(content="Extract memory candidates as JSON."),
|
||||||
|
]
|
||||||
|
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="memory-extraction",
|
||||||
|
model=model_for_agent("memory-extractor"),
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input=messages,
|
||||||
|
) as gen:
|
||||||
|
response = await llm_json.ainvoke(messages)
|
||||||
|
gen.update(output=response.content, usage=extract_usage(response))
|
||||||
|
else:
|
||||||
|
response = await llm_json.ainvoke(messages)
|
||||||
|
|
||||||
|
raw = json.loads(response.content)
|
||||||
|
result = ExtractionResult.model_validate(raw)
|
||||||
|
logger.info("memory_extraction: extracted %d candidates", len(result.candidates))
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_extraction: extract_candidates failed: %s", exc)
|
||||||
|
return ExtractionResult(candidates=[])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task 2.2 — Decide action ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def decide_action(
|
||||||
|
candidate: MemoryCandidate,
|
||||||
|
existing: list[str],
|
||||||
|
) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]:
|
||||||
|
"""Decide what to do with a candidate given existing memories in the same tier.
|
||||||
|
|
||||||
|
Short-circuits to ADD without an LLM call when existing is empty (cost saving).
|
||||||
|
Never raises.
|
||||||
|
"""
|
||||||
|
if not existing:
|
||||||
|
return "ADD"
|
||||||
|
|
||||||
|
candidate_str = f"[{candidate.type}] {candidate.content}"
|
||||||
|
existing_str = "\n".join(f"- {m}" for m in existing)
|
||||||
|
|
||||||
|
template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK)
|
||||||
|
|
||||||
|
if prompt_obj is not None:
|
||||||
|
try:
|
||||||
|
system_text = prompt_obj.compile(
|
||||||
|
candidate=candidate_str,
|
||||||
|
existing_memories=existing_str,
|
||||||
|
)
|
||||||
|
if isinstance(system_text, list):
|
||||||
|
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_extraction: decide compile failed: %s", exc)
|
||||||
|
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
||||||
|
else:
|
||||||
|
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
||||||
|
|
||||||
|
llm = get_agent_llm("memory-extractor", temperature=0)
|
||||||
|
lf = get_langfuse()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=system_text),
|
||||||
|
HumanMessage(content="Decide action."),
|
||||||
|
]
|
||||||
|
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="memory-decide-action",
|
||||||
|
model=model_for_agent("memory-extractor"),
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input=messages,
|
||||||
|
) as gen:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
gen.update(output=response.content, usage=extract_usage(response))
|
||||||
|
else:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
|
||||||
|
verb = response.content.strip().upper()
|
||||||
|
if verb in ("ADD", "UPDATE", "DELETE", "NOOP"):
|
||||||
|
return verb # type: ignore[return-value]
|
||||||
|
logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb)
|
||||||
|
return "ADD"
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_extraction: decide_action failed: %s", exc)
|
||||||
|
return "ADD"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task 2.3 — Pipeline orchestrator ──────────────────────────────────────────
|
||||||
|
|
||||||
|
async def run_extraction(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
last_user_msg: str,
|
||||||
|
last_assistant_msg: str,
|
||||||
|
session_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Full Mem0-style extract/update pipeline for one conversation turn.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Load core memory + last 5 episodes.
|
||||||
|
2. extract_candidates() → up to 5 MemoryCandidate objects.
|
||||||
|
3. For each candidate: find top-3 neighbours → decide_action() → apply.
|
||||||
|
4. Trace via Langfuse.
|
||||||
|
|
||||||
|
Never raises — wraps everything in try/except.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await _run_extraction_inner(db, user_id, last_user_msg, last_assistant_msg, session_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_extraction_inner(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
last_user_msg: str,
|
||||||
|
last_assistant_msg: str,
|
||||||
|
session_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db)
|
||||||
|
fernet = await middleware._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 1. Load context
|
||||||
|
core: dict[str, str] = await middleware._load_core(user_id, fernet)
|
||||||
|
episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
|
|
||||||
|
last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}"
|
||||||
|
|
||||||
|
lf = get_langfuse()
|
||||||
|
|
||||||
|
async def _run(trace_id: str | None) -> dict[str, Any]:
|
||||||
|
# 2. Extract candidates
|
||||||
|
result = await extract_candidates(last_turn, core, episodes)
|
||||||
|
if not result.candidates:
|
||||||
|
logger.info("memory_extraction: no candidates user=%s", user_id)
|
||||||
|
return {"candidates": 0, "applied": 0}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"memory_extraction: processing %d candidates user=%s trace=%s",
|
||||||
|
len(result.candidates),
|
||||||
|
user_id,
|
||||||
|
trace_id or "-",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Apply each candidate
|
||||||
|
applied = 0
|
||||||
|
actions: list[str] = []
|
||||||
|
for candidate in result.candidates:
|
||||||
|
try:
|
||||||
|
await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id)
|
||||||
|
applied += 1
|
||||||
|
actions.append(f"{candidate.type}:{candidate.target_tier}")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory_extraction: apply failed candidate=%r user=%s: %s",
|
||||||
|
candidate.content[:80],
|
||||||
|
user_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"memory_extraction: applied %d/%d candidates user=%s",
|
||||||
|
applied,
|
||||||
|
len(result.candidates),
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
return {"candidates": len(result.candidates), "applied": applied, "actions": actions}
|
||||||
|
|
||||||
|
with langfuse_context(user_id=user_id, session_id=session_id):
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="span",
|
||||||
|
name="memory-extraction-pipeline",
|
||||||
|
input={"last_turn_preview": last_turn[:200]},
|
||||||
|
) as span:
|
||||||
|
summary = await _run(trace_id=span.id)
|
||||||
|
span.update(output=summary)
|
||||||
|
try:
|
||||||
|
lf.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
await _run(trace_id=None)
|
||||||
|
|
||||||
|
|
||||||
|
async def _apply_candidate(
|
||||||
|
middleware: Any,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
fernet: Any,
|
||||||
|
candidate: MemoryCandidate,
|
||||||
|
trace_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Fetch neighbours, decide action, apply to the appropriate tier."""
|
||||||
|
|
||||||
|
neighbours: list[str] = []
|
||||||
|
|
||||||
|
if candidate.target_tier == "core":
|
||||||
|
# For core tier: neighbours are existing core block values for similar keys.
|
||||||
|
blocks = await middleware.list_core_blocks(user_id)
|
||||||
|
neighbours = [b["value"] for b in blocks[:3]]
|
||||||
|
|
||||||
|
elif candidate.target_tier == "associative":
|
||||||
|
neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3)
|
||||||
|
|
||||||
|
elif candidate.target_tier == "relational":
|
||||||
|
# Relation candidates handled specially — passed to upsert_relation directly.
|
||||||
|
# Neighbours: search by subject label if available.
|
||||||
|
neighbours = []
|
||||||
|
|
||||||
|
elif candidate.target_tier == "proactive":
|
||||||
|
neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3)
|
||||||
|
|
||||||
|
action = await decide_action(candidate, neighbours)
|
||||||
|
logger.info(
|
||||||
|
"memory_extraction: candidate type=%s tier=%s action=%s",
|
||||||
|
candidate.type,
|
||||||
|
candidate.target_tier,
|
||||||
|
action,
|
||||||
|
)
|
||||||
|
|
||||||
|
if action == "NOOP":
|
||||||
|
return
|
||||||
|
|
||||||
|
if candidate.target_tier == "relational":
|
||||||
|
# Always upsert relations — decide_action skipped (no neighbour search).
|
||||||
|
if candidate.subject and candidate.predicate and candidate.object:
|
||||||
|
await _upsert_relation(
|
||||||
|
middleware, db, user_id, candidate, trace_id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if action in ("ADD", "UPDATE"):
|
||||||
|
if candidate.target_tier == "core":
|
||||||
|
# Derive a short key from the content (first 40 chars, snake_cased).
|
||||||
|
key = _content_to_key(candidate.content)
|
||||||
|
await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id)
|
||||||
|
|
||||||
|
elif candidate.target_tier == "associative":
|
||||||
|
await middleware.store_associative(user_id, candidate.content)
|
||||||
|
|
||||||
|
elif candidate.target_tier == "proactive":
|
||||||
|
await _store_proactive_stub(middleware, db, user_id, candidate, fernet)
|
||||||
|
|
||||||
|
elif action == "DELETE":
|
||||||
|
if candidate.target_tier == "core":
|
||||||
|
key = _content_to_key(candidate.content)
|
||||||
|
await middleware.delete_core(user_id, key)
|
||||||
|
|
||||||
|
|
||||||
|
def _content_to_key(content: str) -> str:
|
||||||
|
"""Derive a short snake_case key from a content string (first 40 chars)."""
|
||||||
|
import re # noqa: PLC0415
|
||||||
|
slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_")
|
||||||
|
return slug or "memory"
|
||||||
|
|
||||||
|
|
||||||
|
async def _upsert_relation(
|
||||||
|
middleware: Any,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
candidate: MemoryCandidate,
|
||||||
|
trace_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3)."""
|
||||||
|
await middleware.upsert_relation(
|
||||||
|
user_id=user_id,
|
||||||
|
subject=candidate.subject or "unknown",
|
||||||
|
subject_type="unknown",
|
||||||
|
predicate=candidate.predicate or "related_to",
|
||||||
|
object_=candidate.object or "unknown",
|
||||||
|
object_type="unknown",
|
||||||
|
confidence=candidate.confidence,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"memory_extraction: upserted relation subject=%s predicate=%s object=%s",
|
||||||
|
candidate.subject,
|
||||||
|
candidate.predicate,
|
||||||
|
candidate.object,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _store_proactive_stub(
|
||||||
|
middleware: Any,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
candidate: MemoryCandidate,
|
||||||
|
fernet: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Store a proactive pattern row directly (MemoryProactive model)."""
|
||||||
|
import uuid # noqa: PLC0415
|
||||||
|
from app.models import MemoryProactive # noqa: PLC0415
|
||||||
|
from app.core.memory_middleware import _encrypt # noqa: PLC0415
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, candidate.content)
|
||||||
|
row = MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
pattern_encrypted=encrypted,
|
||||||
|
confidence=candidate.confidence,
|
||||||
|
source="inferred",
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
logger.info("memory_extraction: stored proactive pattern user=%s", user_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_extraction: store proactive failed: %s", exc)
|
||||||
|
await db.rollback()
|
||||||
581
app/core/memory_maintenance.py
Normal file
581
app/core/memory_maintenance.py
Normal file
@@ -0,0 +1,581 @@
|
|||||||
|
"""Memory maintenance jobs — Phase 3/5.
|
||||||
|
|
||||||
|
Three entrypoints called by the scheduler (APScheduler) registered in app/main.py:
|
||||||
|
|
||||||
|
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
|
||||||
|
mine_proactive_patterns(db, user_id) — Power+ pattern mining (Phase 5).
|
||||||
|
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
|
||||||
|
|
||||||
|
All are safe to call manually or from tests; they never raise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
||||||
|
from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Decay parameters for relations
|
||||||
|
_DECAY_FACTOR = 0.95
|
||||||
|
_DECAY_PERIOD_DAYS = 30
|
||||||
|
_PRUNE_THRESHOLD = 0.2
|
||||||
|
|
||||||
|
# Proactive pattern decay: 10 % per 7 days since last sighting
|
||||||
|
_PROACTIVE_DECAY_FACTOR = 0.9
|
||||||
|
_PROACTIVE_DECAY_PERIOD_DAYS = 7
|
||||||
|
_PROACTIVE_PRUNE_THRESHOLD = 0.2
|
||||||
|
|
||||||
|
# Mining: require at least this many episodes to attempt pattern extraction
|
||||||
|
_MIN_EPISODES_FOR_MINING = 3
|
||||||
|
_MINING_LOOKBACK_DAYS = 30
|
||||||
|
|
||||||
|
# Audit: caps to control token cost
|
||||||
|
_AUDIT_MAX_FACTS = 50
|
||||||
|
_AUDIT_MAX_LABELS = 100
|
||||||
|
|
||||||
|
|
||||||
|
async def decay_relations(db: AsyncSession, user_id: str) -> None:
|
||||||
|
"""Apply confidence decay to all relation rows for a user.
|
||||||
|
|
||||||
|
Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at.
|
||||||
|
Rows whose confidence falls below 0.2 are deleted.
|
||||||
|
|
||||||
|
Never raises — wraps in try/except.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await _decay_relations_inner(db, user_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
|
||||||
|
result = await db.execute(
|
||||||
|
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
deleted = 0
|
||||||
|
decayed = 0
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
reference = row.last_confirmed_at or row.created_at
|
||||||
|
if reference is None:
|
||||||
|
continue
|
||||||
|
if reference.tzinfo is None:
|
||||||
|
reference = reference.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
days_elapsed = (now - reference).days
|
||||||
|
if days_elapsed < _DECAY_PERIOD_DAYS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
periods = days_elapsed // _DECAY_PERIOD_DAYS
|
||||||
|
new_confidence = row.confidence * (_DECAY_FACTOR ** periods)
|
||||||
|
|
||||||
|
if new_confidence < _PRUNE_THRESHOLD:
|
||||||
|
await db.delete(row)
|
||||||
|
deleted += 1
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
|
||||||
|
"confidence=%.3f (below threshold)",
|
||||||
|
row.id, user_id, row.subject_label, row.predicate, new_confidence,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
row.confidence = new_confidence
|
||||||
|
decayed += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
|
||||||
|
user_id, decayed, deleted,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc)
|
||||||
|
await db.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
async def drain_extraction_queue(db: AsyncSession) -> None:
|
||||||
|
"""Process pending ExtractionQueue rows for Free-tier users.
|
||||||
|
|
||||||
|
Each row corresponds to a stored episode that should be fed through the
|
||||||
|
Mem0-style extraction pipeline. Rows are deleted after successful processing.
|
||||||
|
Never raises — wraps in try/except.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await _drain_extraction_queue_inner(db)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _drain_extraction_queue_inner(db: AsyncSession) -> None:
|
||||||
|
from app.models import ExtractionQueue # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(select(ExtractionQueue))
|
||||||
|
rows = result.scalars().all()
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
logger.debug("memory_maintenance: drain_extraction_queue nothing to drain")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("memory_maintenance: drain_extraction_queue pending=%d", len(rows))
|
||||||
|
|
||||||
|
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
||||||
|
|
||||||
|
processed = 0
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
await run_extraction(
|
||||||
|
db=db,
|
||||||
|
user_id=row.user_id,
|
||||||
|
last_user_msg="",
|
||||||
|
last_assistant_msg="",
|
||||||
|
session_id=None,
|
||||||
|
)
|
||||||
|
await db.delete(row)
|
||||||
|
await db.commit()
|
||||||
|
processed += 1
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory_maintenance: drain failed row=%s user=%s: %s",
|
||||||
|
row.id, row.user_id, exc,
|
||||||
|
)
|
||||||
|
await db.rollback()
|
||||||
|
|
||||||
|
logger.info("memory_maintenance: drain_extraction_queue processed=%d/%d", processed, len(rows))
|
||||||
|
|
||||||
|
|
||||||
|
async def mine_proactive_patterns(db: AsyncSession, user_id: str) -> None:
|
||||||
|
"""Mine recurring behavioral patterns from last 30 days of episodes (Power+ only).
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Gate on proactive_mining tier feature.
|
||||||
|
2. Load + decrypt last 30 days of episodic summaries.
|
||||||
|
3. Call gpt-4o-mini to identify recurring patterns.
|
||||||
|
4. Encrypt and store each pattern in memory_proactive.
|
||||||
|
5. Apply decay to existing proactive rows.
|
||||||
|
|
||||||
|
Never raises — wraps in try/except.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await _mine_proactive_patterns_inner(db, user_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: mine_proactive_patterns failed user=%s: %s", user_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _mine_proactive_patterns_inner(db: AsyncSession, user_id: str) -> None:
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
|
||||||
|
tier = await tier_manager.get_tier(user_id, db)
|
||||||
|
if not tier_manager.check_feature(tier, "proactive_mining"):
|
||||||
|
logger.debug("memory_maintenance: mine_proactive_patterns skipped (tier=%s)", tier)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load user Fernet key
|
||||||
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not user.encryption_key:
|
||||||
|
logger.warning("memory_maintenance: mine_proactive_patterns no encryption_key user=%s", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
fernet = Fernet(user.encryption_key.encode())
|
||||||
|
cutoff = datetime.now(timezone.utc) - timedelta(days=_MINING_LOOKBACK_DAYS)
|
||||||
|
|
||||||
|
episodes_result = await db.execute(
|
||||||
|
select(MemoryEpisodic)
|
||||||
|
.where(
|
||||||
|
MemoryEpisodic.user_id == user_id,
|
||||||
|
MemoryEpisodic.created_at >= cutoff,
|
||||||
|
)
|
||||||
|
.order_by(MemoryEpisodic.created_at.asc())
|
||||||
|
)
|
||||||
|
episode_rows = episodes_result.scalars().all()
|
||||||
|
|
||||||
|
if len(episode_rows) < _MIN_EPISODES_FOR_MINING:
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: mine_proactive_patterns skipped user=%s episodes=%d (< %d)",
|
||||||
|
user_id, len(episode_rows), _MIN_EPISODES_FOR_MINING,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
summaries: list[str] = []
|
||||||
|
for ep in episode_rows:
|
||||||
|
try:
|
||||||
|
plaintext = fernet.decrypt(ep.summary_encrypted.encode()).decode()
|
||||||
|
summaries.append(plaintext)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not summaries:
|
||||||
|
return
|
||||||
|
|
||||||
|
patterns = await _extract_proactive_patterns(summaries)
|
||||||
|
if not patterns:
|
||||||
|
logger.info("memory_maintenance: mine_proactive_patterns user=%s no patterns extracted", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
stored = 0
|
||||||
|
for pattern_text in patterns:
|
||||||
|
try:
|
||||||
|
encrypted = fernet.encrypt(pattern_text.encode()).decode()
|
||||||
|
row = MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
pattern_encrypted=encrypted,
|
||||||
|
confidence=0.7,
|
||||||
|
source="inferred",
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
stored += 1
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: failed to store pattern user=%s: %s", user_id, exc)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: mine_proactive_patterns user=%s stored=%d",
|
||||||
|
user_id, stored,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: mine_proactive_patterns commit failed user=%s: %s", user_id, exc)
|
||||||
|
await db.rollback()
|
||||||
|
return
|
||||||
|
|
||||||
|
await _decay_proactive_patterns(db, user_id, fernet)
|
||||||
|
|
||||||
|
|
||||||
|
async def _extract_proactive_patterns(summaries: list[str]) -> list[str]:
|
||||||
|
"""Call memory-miner LLM to identify recurring behavioral/temporal patterns."""
|
||||||
|
from app.core.llm import get_agent_llm # noqa: PLC0415
|
||||||
|
|
||||||
|
llm = get_agent_llm("memory-miner", temperature=0)
|
||||||
|
combined = "\n---\n".join(summaries[-20:]) # cap at last 20 to control token usage
|
||||||
|
prompt = (
|
||||||
|
"You are analyzing conversation history for a personal AI secretary. "
|
||||||
|
"Identify 3-5 recurring temporal or behavioral patterns (e.g. 'always works late on Thursdays', "
|
||||||
|
"'prefers bullet-point summaries', 'frequently asks about Project Acme status'). "
|
||||||
|
"Return each pattern as a plain, short English sentence on its own line. "
|
||||||
|
"No numbering, no bullet points, no extra text.\n\n"
|
||||||
|
f"Conversation history:\n{combined}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke(prompt)
|
||||||
|
text = response.content if hasattr(response, "content") else str(response)
|
||||||
|
lines = [line.strip() for line in str(text).splitlines() if line.strip()]
|
||||||
|
return lines[:5]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: _extract_proactive_patterns LLM failed: %s", exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None:
|
||||||
|
"""Decay confidence of existing proactive patterns; prune below threshold."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(MemoryProactive).where(MemoryProactive.user_id == user_id)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
deleted = 0
|
||||||
|
decayed = 0
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
reference = row.created_at
|
||||||
|
if reference is None:
|
||||||
|
continue
|
||||||
|
if reference.tzinfo is None:
|
||||||
|
reference = reference.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
days_elapsed = (now - reference).days
|
||||||
|
if days_elapsed < _PROACTIVE_DECAY_PERIOD_DAYS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
periods = days_elapsed // _PROACTIVE_DECAY_PERIOD_DAYS
|
||||||
|
new_confidence = row.confidence * (_PROACTIVE_DECAY_FACTOR ** periods)
|
||||||
|
|
||||||
|
if new_confidence < _PROACTIVE_PRUNE_THRESHOLD:
|
||||||
|
await db.delete(row)
|
||||||
|
deleted += 1
|
||||||
|
else:
|
||||||
|
row.confidence = new_confidence
|
||||||
|
decayed += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: decay_proactive user=%s decayed=%d deleted=%d",
|
||||||
|
user_id, decayed, deleted,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
|
||||||
|
await db.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
_AUDIT_CONTRADICTIONS_FALLBACK = (
|
||||||
|
"You are auditing a personal AI assistant's memory bank. "
|
||||||
|
"Each fact has an ID in brackets. "
|
||||||
|
"Find pairs that directly contradict each other "
|
||||||
|
"(e.g. 'prefers morning meetings' vs 'never schedules before noon'). "
|
||||||
|
"For each contradiction, pick the ID to DELETE (the older or less specific one). "
|
||||||
|
'Return ONLY a valid JSON array, no markdown fences: '
|
||||||
|
'[{{"delete": "<id>", "reason": "<one line>"}}]. '
|
||||||
|
"If no contradictions, return [].\n\n"
|
||||||
|
"Facts:\n{facts}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_AUDIT_CANONICALIZE_FALLBACK = (
|
||||||
|
"You are auditing entity labels in a personal AI assistant's relational memory. "
|
||||||
|
"These are names of people, companies, projects, or topics. "
|
||||||
|
"Group labels that clearly refer to the same real-world entity "
|
||||||
|
"(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). "
|
||||||
|
"Return ONLY a valid JSON array, no markdown fences: "
|
||||||
|
'[{{"canonical": "<best label>", "variants": ["<v1>", "<v2>"]}}]. '
|
||||||
|
"Only include groups with at least one variant. Singletons: omit.\n\n"
|
||||||
|
"Labels:\n{labels}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def audit_memory(db: AsyncSession, user_id: str) -> None:
|
||||||
|
"""Weekly audit: contradiction scan on associative facts + label canonicalization on relations.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM.
|
||||||
|
2. LLM flags rows to delete (direct contradictions); hard-delete them.
|
||||||
|
3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates.
|
||||||
|
4. Rewrite variant labels to their canonical form in-place.
|
||||||
|
|
||||||
|
Never raises — wraps in try/except.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await _audit_memory_inner(db, user_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
|
||||||
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not user.encryption_key:
|
||||||
|
logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
fernet = Fernet(user.encryption_key.encode())
|
||||||
|
await _scan_associative_contradictions(db, user_id, fernet)
|
||||||
|
await _canonicalize_relation_labels(db, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _scan_associative_contradictions(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
fernet: Fernet,
|
||||||
|
) -> None:
|
||||||
|
"""Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc())
|
||||||
|
.limit(_AUDIT_MAX_FACTS)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
if len(rows) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
id_to_text: dict[str, str] = {}
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
plaintext = fernet.decrypt(row.content_encrypted.encode()).decode()
|
||||||
|
id_to_text[row.id] = plaintext
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if len(id_to_text) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
id_list = list(id_to_text.keys())
|
||||||
|
numbered = "\n".join(
|
||||||
|
f"{i + 1}. [{rid}] {id_to_text[rid]}" for i, rid in enumerate(id_list)
|
||||||
|
)
|
||||||
|
|
||||||
|
template, prompt_obj = get_prompt_or_fallback(
|
||||||
|
"memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
|
||||||
|
)
|
||||||
|
system_text = compile_prompt(template, prompt_obj, facts=numbered)
|
||||||
|
|
||||||
|
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||||
|
|
||||||
|
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||||
|
lf = get_langfuse()
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=system_text),
|
||||||
|
HumanMessage(content="Audit facts for contradictions."),
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="memory-audit-contradictions",
|
||||||
|
model=model_for_agent("memory-auditor"),
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input=messages,
|
||||||
|
) as gen:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
gen.update(output=response.content, usage=extract_usage(response))
|
||||||
|
else:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
|
||||||
|
text = response.content if hasattr(response, "content") else str(response)
|
||||||
|
deletions = json.loads(text.strip())
|
||||||
|
if not isinstance(deletions, list):
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s",
|
||||||
|
user_id, exc,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
deleted = 0
|
||||||
|
for item in deletions:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
rid = item.get("delete")
|
||||||
|
if not rid or rid not in id_to_text:
|
||||||
|
continue
|
||||||
|
result2 = await db.execute(
|
||||||
|
select(MemoryAssociative).where(
|
||||||
|
MemoryAssociative.id == rid,
|
||||||
|
MemoryAssociative.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
target = result2.scalar_one_or_none()
|
||||||
|
if target:
|
||||||
|
await db.delete(target)
|
||||||
|
deleted += 1
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s",
|
||||||
|
rid, user_id, item.get("reason", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
if deleted:
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
await db.rollback()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
|
||||||
|
"""Group near-duplicate entity labels in memory_relations and unify to canonical form."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
if not rows:
|
||||||
|
return
|
||||||
|
|
||||||
|
all_labels: set[str] = set()
|
||||||
|
for row in rows:
|
||||||
|
all_labels.add(row.subject_label)
|
||||||
|
all_labels.add(row.object_label)
|
||||||
|
|
||||||
|
labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
|
||||||
|
if len(labels_list) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
|
||||||
|
template, prompt_obj = get_prompt_or_fallback(
|
||||||
|
"memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
|
||||||
|
)
|
||||||
|
system_text = compile_prompt(template, prompt_obj, labels=labels_block)
|
||||||
|
|
||||||
|
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||||
|
|
||||||
|
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||||
|
lf = get_langfuse()
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=system_text),
|
||||||
|
HumanMessage(content="Canonicalize entity labels."),
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="memory-audit-canonicalize",
|
||||||
|
model=model_for_agent("memory-auditor"),
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input=messages,
|
||||||
|
) as gen:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
gen.update(output=response.content, usage=extract_usage(response))
|
||||||
|
else:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
|
||||||
|
text = response.content if hasattr(response, "content") else str(response)
|
||||||
|
groups = json.loads(text.strip())
|
||||||
|
if not isinstance(groups, list):
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s",
|
||||||
|
user_id, exc,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build variant → canonical map
|
||||||
|
remap: dict[str, str] = {}
|
||||||
|
for group in groups:
|
||||||
|
if not isinstance(group, dict):
|
||||||
|
continue
|
||||||
|
canonical = group.get("canonical", "")
|
||||||
|
variants = group.get("variants") or []
|
||||||
|
if not canonical:
|
||||||
|
continue
|
||||||
|
for v in variants:
|
||||||
|
if isinstance(v, str) and v != canonical:
|
||||||
|
remap[v] = canonical
|
||||||
|
|
||||||
|
if not remap:
|
||||||
|
return
|
||||||
|
|
||||||
|
updated = 0
|
||||||
|
for row in rows:
|
||||||
|
changed = False
|
||||||
|
if row.subject_label in remap:
|
||||||
|
row.subject_label = remap[row.subject_label]
|
||||||
|
changed = True
|
||||||
|
if row.object_label in remap:
|
||||||
|
row.object_label = remap[row.object_label]
|
||||||
|
changed = True
|
||||||
|
if changed:
|
||||||
|
updated += 1
|
||||||
|
|
||||||
|
if updated:
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
|
||||||
|
user_id, updated,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
await db.rollback()
|
||||||
@@ -18,8 +18,10 @@ Usage:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from cryptography.fernet import Fernet, InvalidToken
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
@@ -27,15 +29,22 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.models import (
|
from app.models import (
|
||||||
|
ExtractionQueue,
|
||||||
MemoryAssociative,
|
MemoryAssociative,
|
||||||
MemoryCore,
|
MemoryCore,
|
||||||
MemoryEpisodic,
|
MemoryEpisodic,
|
||||||
MemoryProactive,
|
MemoryProactive,
|
||||||
|
MemoryRelation,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
# Tuning constants
|
# Tuning constants
|
||||||
_ASSOCIATIVE_TOP_K = 5
|
_ASSOCIATIVE_TOP_K = 5
|
||||||
_EPISODIC_RECENT_N = 10
|
_EPISODIC_RECENT_N = 10
|
||||||
@@ -50,7 +59,13 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────────────
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
async def enrich_context(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Build memory context dict to inject into the orchestrator before LLM call.
|
"""Build memory context dict to inject into the orchestrator before LLM call.
|
||||||
|
|
||||||
Returns a dict with keys:
|
Returns a dict with keys:
|
||||||
@@ -58,21 +73,39 @@ class MemoryMiddleware:
|
|||||||
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
||||||
episodic_memory — [plaintext_summary, ...] (most recent N)
|
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||||
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||||
|
relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+)
|
||||||
"""
|
"""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
user_tier: str = user_dbg.get("tier") or "free"
|
||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
core = await self._load_core(user_id, fernet)
|
||||||
associative = await self._load_associative(user_id, message, fernet)
|
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
|
||||||
episodic = await self._load_episodic(user_id, fernet)
|
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
relational = await self._load_relational(user_id, user_tier=user_tier)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_tier,
|
||||||
|
len(core),
|
||||||
|
len(associative),
|
||||||
|
len(episodic),
|
||||||
|
len(proactive),
|
||||||
|
len(relational),
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"core_memory": core,
|
"core_memory": core,
|
||||||
"associative_memory": associative,
|
"associative_memory": associative,
|
||||||
"episodic_memory": episodic,
|
"episodic_memory": episodic,
|
||||||
"proactive_hints": proactive,
|
"proactive_hints": proactive,
|
||||||
|
"relational_memory": relational,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def store_episode(
|
async def store_episode(
|
||||||
@@ -81,11 +114,15 @@ class MemoryMiddleware:
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
response: str,
|
response: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Summarise and store a completed interaction in episodic memory.
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
The summary is a simple heuristic concatenation (no LLM call) to keep
|
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||||
latency low. Full LLM summarisation can be added in a later step.
|
latency low. After committing the episode row, dispatches the Mem0-style
|
||||||
|
extraction pipeline:
|
||||||
|
- Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime).
|
||||||
|
- Free → enqueue an ExtractionQueue row for the daily cron.
|
||||||
"""
|
"""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -94,20 +131,97 @@ class MemoryMiddleware:
|
|||||||
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||||
encrypted = _encrypt(fernet, summary)
|
encrypted = _encrypt(fernet, summary)
|
||||||
|
|
||||||
row = MemoryEpisodic(
|
episode = MemoryEpisodic(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
summary_encrypted=encrypted,
|
summary_encrypted=encrypted,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
self._db.add(row)
|
self._db.add(episode)
|
||||||
|
episode_id: str = episode.id
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
tier = user_dbg.get("tier") or "free"
|
||||||
|
logger.info(
|
||||||
|
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
tier,
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
return
|
||||||
|
|
||||||
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
# ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
|
||||||
|
await self._dispatch_extraction(
|
||||||
|
user_id=user_id,
|
||||||
|
episode_id=episode_id,
|
||||||
|
last_user_msg=message,
|
||||||
|
last_assistant_msg=response,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _dispatch_extraction(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
episode_id: str,
|
||||||
|
last_user_msg: str,
|
||||||
|
last_assistant_msg: str,
|
||||||
|
session_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Route extraction to realtime task or batch queue based on user tier."""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
|
||||||
|
tier = await tier_manager.get_tier(user_id, self._db)
|
||||||
|
|
||||||
|
if tier_manager.check_feature(tier, "realtime_extraction"):
|
||||||
|
# Pro/Power/Team: fire-and-forget in the background.
|
||||||
|
# Must open a fresh session — request session closes after handler returns.
|
||||||
|
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
||||||
|
from app.db import async_session # noqa: PLC0415
|
||||||
|
|
||||||
|
async def _task() -> None:
|
||||||
|
try:
|
||||||
|
async with async_session() as fresh_db:
|
||||||
|
await run_extraction(
|
||||||
|
db=fresh_db,
|
||||||
|
user_id=user_id,
|
||||||
|
last_user_msg=last_user_msg,
|
||||||
|
last_assistant_msg=last_assistant_msg,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory: extraction task failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.create_task(_task())
|
||||||
|
logger.info("memory: realtime extraction dispatched user=%s", user_id)
|
||||||
|
else:
|
||||||
|
# Free tier: enqueue for daily batch cron.
|
||||||
|
queue_row = ExtractionQueue(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
episode_id=episode_id,
|
||||||
|
)
|
||||||
|
self._db.add(queue_row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory: extraction enqueued (batch) user=%s episode=%s",
|
||||||
|
user_id,
|
||||||
|
episode_id,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory: extraction queue insert failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||||
"""Upsert a core memory key/value for a user."""
|
"""Upsert a core memory key/value for a user."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -133,10 +247,313 @@ class MemoryMiddleware:
|
|||||||
))
|
))
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: update_core trace=%s user=%s tier=%s key=%s",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
key,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
||||||
|
"""Return core memory as editable blocks (label/value)."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore)
|
||||||
|
.where(MemoryCore.user_id == user_id)
|
||||||
|
.order_by(MemoryCore.key.asc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[dict[str, str]] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append({"label": row.key, "value": plaintext})
|
||||||
|
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
||||||
|
"""Return a single core memory block value by label."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == label,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
|
||||||
|
return None
|
||||||
|
value = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def delete_core(self, user_id: str, label: str) -> bool:
|
||||||
|
"""Delete a core memory block by label. Returns True if deleted."""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == label,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self._db.delete(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("memory: delete_core user=%s label=%s", user_id, label)
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
||||||
|
"""Append content to a core block, creating it if missing."""
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None:
|
||||||
|
await self.update_core(user_id, label, content)
|
||||||
|
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
|
||||||
|
return
|
||||||
|
await self.update_core(user_id, label, f"{current}\n{content}")
|
||||||
|
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
|
||||||
|
|
||||||
|
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
||||||
|
"""Replace one exact string inside a core block. Returns False if not found."""
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None or old not in current:
|
||||||
|
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
|
||||||
|
return False
|
||||||
|
await self.update_core(user_id, label, current.replace(old, new, 1))
|
||||||
|
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def store_associative(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
content: str,
|
||||||
|
entity_type: str | None = None,
|
||||||
|
entity_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Store associative memory; embed if user tier has real_embeddings."""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||||
|
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, content)
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
user_tier = user_dbg.get("tier") or "free"
|
||||||
|
|
||||||
|
embedding: list[float] | None = None
|
||||||
|
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||||
|
embedding = await embed_text(content)
|
||||||
|
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
content_encrypted=encrypted,
|
||||||
|
embedding=embedding,
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_id=entity_id,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory: store_associative user=%s embedded=%s",
|
||||||
|
user_id,
|
||||||
|
embedding is not None,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def upsert_relation(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
subject: str,
|
||||||
|
subject_type: str,
|
||||||
|
predicate: str,
|
||||||
|
object_: str,
|
||||||
|
object_type: str,
|
||||||
|
*,
|
||||||
|
confidence: float = 0.7,
|
||||||
|
source_episode_id: str | None = None,
|
||||||
|
notes: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).
|
||||||
|
|
||||||
|
subject_label / object_label are plaintext entity identifiers — not encrypted.
|
||||||
|
notes is optional; encrypted with user Fernet if provided.
|
||||||
|
"""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
user_tier = user_dbg.get("tier") or "free"
|
||||||
|
if not tier_manager.check_feature(user_tier, "relational_memory"):
|
||||||
|
logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
|
||||||
|
return
|
||||||
|
|
||||||
|
notes_encrypted: bytes | None = None
|
||||||
|
if notes:
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet:
|
||||||
|
notes_encrypted = fernet.encrypt(notes.encode())
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryRelation).where(
|
||||||
|
MemoryRelation.user_id == user_id,
|
||||||
|
MemoryRelation.subject_label == subject,
|
||||||
|
MemoryRelation.predicate == predicate,
|
||||||
|
MemoryRelation.object_label == object_,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if existing is not None:
|
||||||
|
existing.subject_type = subject_type
|
||||||
|
existing.object_type = object_type
|
||||||
|
existing.confidence = confidence
|
||||||
|
existing.last_confirmed_at = _now()
|
||||||
|
if notes_encrypted is not None:
|
||||||
|
existing.notes_encrypted = notes_encrypted
|
||||||
|
else:
|
||||||
|
self._db.add(MemoryRelation(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
subject_label=subject,
|
||||||
|
subject_type=subject_type,
|
||||||
|
predicate=predicate,
|
||||||
|
object_label=object_,
|
||||||
|
object_type=object_type,
|
||||||
|
confidence=confidence,
|
||||||
|
source_episode_id=source_episode_id,
|
||||||
|
notes_encrypted=notes_encrypted,
|
||||||
|
))
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
|
||||||
|
user_id, subject, predicate, object_,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def query_relations(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
subject: str | None = None,
|
||||||
|
predicate: str | None = None,
|
||||||
|
object_: str | None = None,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> list[MemoryRelation]:
|
||||||
|
"""Query relation rows for a user with optional filters."""
|
||||||
|
q = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||||
|
if subject is not None:
|
||||||
|
q = q.where(MemoryRelation.subject_label == subject)
|
||||||
|
if predicate is not None:
|
||||||
|
q = q.where(MemoryRelation.predicate == predicate)
|
||||||
|
if object_ is not None:
|
||||||
|
q = q.where(MemoryRelation.object_label == object_)
|
||||||
|
q = q.order_by(MemoryRelation.confidence.desc()).limit(limit)
|
||||||
|
result = await self._db.execute(q)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||||
|
"""Insert a long-term archival memory entry."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, content)
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
content_encrypted=encrypted,
|
||||||
|
embedding=None,
|
||||||
|
entity_type=source,
|
||||||
|
entity_id=None,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc())
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
"""Search recall memory (episodic summaries) by keyword."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
# ── Private helpers ───────────────────────────────────────────────────────
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
@@ -148,6 +565,29 @@ class MemoryMiddleware:
|
|||||||
return None
|
return None
|
||||||
return Fernet(user.encryption_key.encode())
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
|
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
||||||
|
"""Load lightweight user debug fields for trace logs."""
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
return {"tier": None}
|
||||||
|
|
||||||
|
sub_result = await self._db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub_tier: str | None = sub_result.scalar_one_or_none()
|
||||||
|
if sub_tier:
|
||||||
|
tier = sub_tier
|
||||||
|
elif settings.ENV == "dev":
|
||||||
|
tier = "power"
|
||||||
|
else:
|
||||||
|
tier = user.tier or "free"
|
||||||
|
|
||||||
|
return {"tier": tier}
|
||||||
|
|
||||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
@@ -161,14 +601,49 @@ class MemoryMiddleware:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_associative(
|
async def _load_associative(
|
||||||
self, user_id: str, message: str, fernet: Fernet
|
self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free"
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Load top-k associative memories.
|
"""Load top-k associative memories.
|
||||||
|
|
||||||
Production: uses pgvector cosine similarity on the message embedding.
|
Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature).
|
||||||
Current implementation: keyword-based fallback (no external embedding call)
|
Free / embedding failure: keyword-ordered fallback (most recent rows).
|
||||||
so tests pass without a live OpenAI key.
|
|
||||||
"""
|
"""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||||
|
|
||||||
|
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||||
|
vec = await embed_text(message)
|
||||||
|
if vec is not None:
|
||||||
|
try:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(
|
||||||
|
MemoryAssociative.user_id == user_id,
|
||||||
|
MemoryAssociative.embedding.isnot(None),
|
||||||
|
)
|
||||||
|
.order_by(MemoryAssociative.embedding.cosine_distance(vec))
|
||||||
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
logger.info(
|
||||||
|
"memory: _load_associative user=%s mode=vector hits=%d",
|
||||||
|
user_id,
|
||||||
|
len(out),
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory: vector search failed user=%s, falling back to keyword: %s",
|
||||||
|
user_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keyword fallback: most recent rows
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryAssociative)
|
select(MemoryAssociative)
|
||||||
.where(MemoryAssociative.user_id == user_id)
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
@@ -176,17 +651,24 @@ class MemoryMiddleware:
|
|||||||
.limit(_ASSOCIATIVE_TOP_K)
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
)
|
)
|
||||||
rows = result.scalars().all()
|
rows = result.scalars().all()
|
||||||
out: list[str] = []
|
out = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
if plaintext is not None:
|
if plaintext is not None:
|
||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
async def _load_episodic(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
fernet: Fernet,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||||
|
if session_id:
|
||||||
|
query = query.where(MemoryEpisodic.session_id == session_id)
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryEpisodic)
|
query
|
||||||
.where(MemoryEpisodic.user_id == user_id)
|
|
||||||
.order_by(MemoryEpisodic.created_at.desc())
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
.limit(_EPISODIC_RECENT_N)
|
.limit(_EPISODIC_RECENT_N)
|
||||||
)
|
)
|
||||||
@@ -198,6 +680,26 @@ class MemoryMiddleware:
|
|||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
|
||||||
|
"""Return top-10 relation strings for Pro+ users; empty list for Free."""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
|
||||||
|
if not tier_manager.check_feature(user_tier, "relational_memory"):
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryRelation)
|
||||||
|
.where(MemoryRelation.user_id == user_id)
|
||||||
|
.order_by(MemoryRelation.confidence.desc())
|
||||||
|
.limit(10)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out = [
|
||||||
|
f"{r.subject_label} --{r.predicate}--> {r.object_label}"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return out
|
||||||
|
|
||||||
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryProactive)
|
select(MemoryProactive)
|
||||||
|
|||||||
@@ -1,210 +0,0 @@
|
|||||||
"""Orchestrator — LLM-based intent router and agent pipeline."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, AsyncGenerator
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
|
|
||||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
|
||||||
from app.core.llm import get_router_llm
|
|
||||||
from app.core.agent_registry import registry as _default_registry
|
|
||||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
|
||||||
|
|
||||||
_FALLBACK_AGENT = "task_agent"
|
|
||||||
|
|
||||||
_CLASSIFY_SYSTEM = (
|
|
||||||
"You are an intent classifier. Given the user message and context, decide "
|
|
||||||
"which agent to route to.\n"
|
|
||||||
"Available agents: {agents}\n"
|
|
||||||
"Respond with just the agent name, nothing else."
|
|
||||||
)
|
|
||||||
|
|
||||||
_SYNTHESIZE_HUMAN = (
|
|
||||||
"Combine the following agent results into one coherent response.\n\n"
|
|
||||||
"Agent results:\n{results}\n\n"
|
|
||||||
"Original message: {message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_llm():
|
|
||||||
return get_router_llm()
|
|
||||||
|
|
||||||
|
|
||||||
async def classify_intent(
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> str:
|
|
||||||
"""Use gpt-4o-mini to classify intent and return the matching agent name.
|
|
||||||
|
|
||||||
Falls back to ``task_agent`` when the registry is empty or the model
|
|
||||||
returns a name that is not registered.
|
|
||||||
"""
|
|
||||||
agents = reg.list_agents()
|
|
||||||
if not agents:
|
|
||||||
return _FALLBACK_AGENT
|
|
||||||
|
|
||||||
system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
|
|
||||||
# Truncate context to keep the classification prompt short
|
|
||||||
human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
|
|
||||||
|
|
||||||
llm = _make_llm()
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[SystemMessage(content=system), HumanMessage(content=human)]
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_name = str(response.content).strip().lower()
|
|
||||||
known = {a["name"] for a in agents}
|
|
||||||
return agent_name if agent_name in known else _FALLBACK_AGENT
|
|
||||||
|
|
||||||
|
|
||||||
async def route_single(
|
|
||||||
agent_name: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""Route to a single agent and wrap the result in a ``ChatResponse``."""
|
|
||||||
response_text = await reg.call_agent(agent_name, message, context)
|
|
||||||
return ChatResponse(response=response_text)
|
|
||||||
|
|
||||||
|
|
||||||
async def route_pipeline(
|
|
||||||
agent_names: list[str],
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""Execute agents sequentially; each agent receives previous results in context.
|
|
||||||
|
|
||||||
A final LLM synthesis call merges all results into one coherent response.
|
|
||||||
"""
|
|
||||||
previous_results: list[str] = []
|
|
||||||
|
|
||||||
for agent_name in agent_names:
|
|
||||||
ctx = {**context, "previous_results": list(previous_results)}
|
|
||||||
result = await reg.call_agent(agent_name, message, ctx)
|
|
||||||
previous_results.append(result)
|
|
||||||
|
|
||||||
results_str = "\n\n".join(
|
|
||||||
f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
|
|
||||||
)
|
|
||||||
human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
|
|
||||||
llm = _make_llm()
|
|
||||||
synthesis = await llm.ainvoke([HumanMessage(content=human)])
|
|
||||||
return ChatResponse(response=str(synthesis.content))
|
|
||||||
|
|
||||||
|
|
||||||
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
|
|
||||||
"""Build an ``ExecutionPlan`` for the resolved agent.
|
|
||||||
|
|
||||||
Uses ``ExecutionPlanBuilder`` with the server-side template registry.
|
|
||||||
If a default template exists for the agent, an LLM step is emitted;
|
|
||||||
otherwise a plain ``handle`` action step is used.
|
|
||||||
"""
|
|
||||||
from app.core.execution_plan import ExecutionPlanBuilder, template_registry
|
|
||||||
|
|
||||||
template_id = f"tpl_{agent_name}_default"
|
|
||||||
builder = ExecutionPlanBuilder(agent_name)
|
|
||||||
if template_registry.has(template_id):
|
|
||||||
builder.add_llm_step(template_id, {"message": message})
|
|
||||||
else:
|
|
||||||
builder.add_step("handle", {"message": message})
|
|
||||||
return builder.build()
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate(
|
|
||||||
request: ChatRequest,
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> ChatResponse | ExecutionPlan:
|
|
||||||
"""Main orchestration entry point.
|
|
||||||
|
|
||||||
* Classifies the user's intent to select an agent.
|
|
||||||
* ``execution_mode == 'direct'``: routes to the agent and returns a
|
|
||||||
``ChatResponse``.
|
|
||||||
* ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
|
|
||||||
resolved agent and a template-ID-only step (prompt IP stays server-side).
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
|
|
||||||
context = request.context.model_dump()
|
|
||||||
agent_name = await classify_intent(request.message, context, reg)
|
|
||||||
|
|
||||||
if request.execution_mode == "direct":
|
|
||||||
return await route_single(agent_name, request.message, context, reg)
|
|
||||||
|
|
||||||
# plan mode — return plan, do not execute
|
|
||||||
return _build_plan(agent_name, request.message)
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_v3(
|
|
||||||
user_id: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> tuple[str, ChatAgent]:
|
|
||||||
"""v3 orchestration — returns (agent_name, agent_instance); caller drives execution.
|
|
||||||
|
|
||||||
Classifies intent and instantiates the matching agent. The caller is responsible
|
|
||||||
for invoking handle(), handle_stream(), or _tool_loop_stream() as needed.
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
agent_name = await classify_intent(message, context, reg)
|
|
||||||
return agent_name, reg.get(agent_name)
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_v3_stream(
|
|
||||||
user_id: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
agent_holder: list | None = None,
|
|
||||||
) -> AsyncGenerator[tuple[str, str], None]:
|
|
||||||
"""v3 streaming orchestration — yields (agent_name, token) pairs.
|
|
||||||
|
|
||||||
The first yield always carries the agent_name with an empty token so that
|
|
||||||
callers (e.g. FloatingFormatter) can detect the routing domain before any text
|
|
||||||
tokens arrive.
|
|
||||||
|
|
||||||
If *agent_holder* is provided (a list), the agent instance is appended so
|
|
||||||
callers can access ``agent.tool_results`` after the stream completes.
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
agent_name = await classify_intent(message, context, reg)
|
|
||||||
agent = reg.get(agent_name)
|
|
||||||
if agent_holder is not None:
|
|
||||||
agent_holder.append(agent)
|
|
||||||
yield agent_name, "" # domain signal — no token yet
|
|
||||||
async for token in agent.handle_stream(message, context):
|
|
||||||
yield agent_name, token
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_stream(
|
|
||||||
request: ChatRequest,
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Streaming orchestration — yields plain text chunks only.
|
|
||||||
|
|
||||||
The WebSocket handler in ``app/api/routes/chat.py`` is responsible for
|
|
||||||
wrapping each chunk in a ``text_chunk`` frame and sending the final
|
|
||||||
``final`` frame once the generator is exhausted.
|
|
||||||
|
|
||||||
Agents do not yet support token-level streaming; the full response is
|
|
||||||
fetched first (which may involve multiple WS round-trips for tool calls),
|
|
||||||
then emitted in fixed-size chunks.
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
|
|
||||||
context = request.context.model_dump()
|
|
||||||
agent_name = await classify_intent(request.message, context, reg)
|
|
||||||
response_text = await reg.call_agent(agent_name, request.message, context)
|
|
||||||
|
|
||||||
chunk_size = 50
|
|
||||||
for i in range(0, len(response_text), chunk_size):
|
|
||||||
yield response_text[i : i + chunk_size]
|
|
||||||
@@ -1,244 +1,47 @@
|
|||||||
"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
|
"""Output formatter for deep-agent stream events."""
|
||||||
|
|
||||||
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
|
|
||||||
FloatingFormatter: produces floating_domain, stream_text, stream_end
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.schemas import (
|
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
WsFloatingDomain,
|
|
||||||
WsStreamBlock,
|
|
||||||
WsStreamEnd,
|
|
||||||
WsStreamStart,
|
|
||||||
WsStreamText,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
|
|
||||||
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}
|
|
||||||
|
|
||||||
# Map agent name → floating domain
|
|
||||||
_AGENT_DOMAIN: dict[str, str] = {
|
|
||||||
"task_agent": "tasks",
|
|
||||||
"timeline_agent": "timelines",
|
|
||||||
"note_agent": "notes",
|
|
||||||
"project_agent": "projects",
|
|
||||||
}
|
|
||||||
|
|
||||||
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
|
|
||||||
|
|
||||||
|
|
||||||
class HomeFormatter:
|
class StreamFormatter:
|
||||||
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
||||||
|
|
||||||
The LLM is expected to output a newline-delimited sequence of JSON objects,
|
|
||||||
each with a ``type`` field:
|
|
||||||
- ``text`` → yields WsStreamText immediately (word-by-word)
|
|
||||||
- ``chart`` → buffers full JSON, validates, yields WsStreamBlock
|
|
||||||
- ``entity_ref`` → resolves from tool_results, yields WsStreamBlock
|
|
||||||
- ``table`` → buffers full JSON, validates, yields WsStreamBlock
|
|
||||||
- ``timeline`` → buffers full JSON, validates, yields WsStreamBlock
|
|
||||||
|
|
||||||
Invalid or unknown blocks are logged and skipped — stream never crashes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, request_id: str, tool_results: list[dict]) -> None:
|
|
||||||
self.request_id = request_id
|
|
||||||
self.tool_results = tool_results
|
|
||||||
|
|
||||||
async def format(
|
|
||||||
self,
|
|
||||||
token_stream: AsyncGenerator[tuple[str, str], None],
|
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
|
||||||
|
|
||||||
buffer = ""
|
|
||||||
async for _agent_name, token in token_stream:
|
|
||||||
if not token:
|
|
||||||
continue
|
|
||||||
buffer += token
|
|
||||||
# Flush any complete JSON objects from the buffer
|
|
||||||
async for frame in self._flush_complete_objects(buffer):
|
|
||||||
buffer = "" # reset after flush
|
|
||||||
yield frame
|
|
||||||
break # only one flush per iteration; rest accumulates
|
|
||||||
|
|
||||||
# Flush any remaining content
|
|
||||||
if buffer.strip():
|
|
||||||
async for frame in self._flush_complete_objects(buffer, final=True):
|
|
||||||
yield frame
|
|
||||||
|
|
||||||
yield WsStreamEnd(request_id=self.request_id)
|
|
||||||
|
|
||||||
async def _flush_complete_objects(
|
|
||||||
self, text: str, final: bool = False
|
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
|
||||||
"""Try to parse and yield all complete JSON objects from *text*.
|
|
||||||
|
|
||||||
Yields nothing if text is incomplete JSON (unless *final* is True,
|
|
||||||
in which case remaining text is emitted as plain stream_text).
|
|
||||||
"""
|
|
||||||
remaining = text.strip()
|
|
||||||
while remaining:
|
|
||||||
# Fast path: plain text (not JSON)
|
|
||||||
if not remaining.startswith("{"):
|
|
||||||
# Yield as plain text chunk
|
|
||||||
newline_idx = remaining.find("\n")
|
|
||||||
if newline_idx == -1:
|
|
||||||
if final:
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
|
||||||
remaining = ""
|
|
||||||
else:
|
|
||||||
return # accumulate more
|
|
||||||
else:
|
|
||||||
line = remaining[:newline_idx].strip()
|
|
||||||
remaining = remaining[newline_idx + 1:].strip()
|
|
||||||
if line:
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=line)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Try to decode a JSON object
|
|
||||||
try:
|
|
||||||
obj, end_idx = _try_parse_json(remaining)
|
|
||||||
except ValueError:
|
|
||||||
if final:
|
|
||||||
# Emit as raw text if we can't parse
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
|
||||||
remaining = ""
|
|
||||||
return
|
|
||||||
|
|
||||||
if obj is None:
|
|
||||||
if final:
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
|
||||||
remaining = ""
|
|
||||||
return # incomplete — need more tokens
|
|
||||||
|
|
||||||
remaining = remaining[end_idx:].strip()
|
|
||||||
block_type = obj.get("type")
|
|
||||||
|
|
||||||
frame = self._dispatch_block(obj, block_type)
|
|
||||||
if frame is not None:
|
|
||||||
yield frame
|
|
||||||
|
|
||||||
def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
|
|
||||||
if block_type == "text":
|
|
||||||
content = obj.get("content", "")
|
|
||||||
if content:
|
|
||||||
return WsStreamText(request_id=self.request_id, chunk=str(content))
|
|
||||||
return None
|
|
||||||
|
|
||||||
if block_type == "chart":
|
|
||||||
chart_type = obj.get("chartType")
|
|
||||||
if chart_type not in _VALID_CHART_TYPES:
|
|
||||||
logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
|
|
||||||
return None
|
|
||||||
if not isinstance(obj.get("data"), list):
|
|
||||||
logger.warning("HomeFormatter: chart missing data array — skipping")
|
|
||||||
return None
|
|
||||||
return WsStreamBlock(
|
|
||||||
request_id=self.request_id,
|
|
||||||
block_type="chart",
|
|
||||||
data=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
if block_type == "entity_ref":
|
|
||||||
entity = obj.get("entity")
|
|
||||||
resolved = self._resolve_entity(entity)
|
|
||||||
if resolved is None:
|
|
||||||
logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
|
|
||||||
return None
|
|
||||||
return WsStreamBlock(
|
|
||||||
request_id=self.request_id,
|
|
||||||
block_type="entity_ref",
|
|
||||||
data={"entity": entity, "items": resolved},
|
|
||||||
)
|
|
||||||
|
|
||||||
if block_type == "table":
|
|
||||||
if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
|
|
||||||
logger.warning("HomeFormatter: table missing headers/rows — skipping")
|
|
||||||
return None
|
|
||||||
return WsStreamBlock(
|
|
||||||
request_id=self.request_id,
|
|
||||||
block_type="table",
|
|
||||||
data=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
if block_type == "timeline":
|
|
||||||
if not isinstance(obj.get("timelines"), list):
|
|
||||||
logger.warning("HomeFormatter: timeline missing timelines — skipping")
|
|
||||||
return None
|
|
||||||
return WsStreamBlock(
|
|
||||||
request_id=self.request_id,
|
|
||||||
block_type="timeline",
|
|
||||||
data=obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _resolve_entity(self, entity: str | None) -> list[dict] | None:
|
|
||||||
"""Find matching items in tool_results by entity type."""
|
|
||||||
if not entity:
|
|
||||||
return None
|
|
||||||
matches = [r for r in self.tool_results if r.get("entity") == entity]
|
|
||||||
return matches if matches else None
|
|
||||||
|
|
||||||
|
|
||||||
class FloatingFormatter:
|
|
||||||
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
|
||||||
|
|
||||||
Emits floating_domain immediately (from agent_name), then streams all tokens
|
|
||||||
as plain stream_text — no block parsing for floating context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, request_id: str) -> None:
|
def __init__(self, request_id: str) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
|
|
||||||
async def format(
|
async def format(
|
||||||
self,
|
self,
|
||||||
token_stream: AsyncGenerator[tuple[str, str], None],
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
domain_sent = False
|
started = False
|
||||||
|
|
||||||
async for agent_name, token in token_stream:
|
async for event_type, data in event_stream:
|
||||||
if not domain_sent:
|
if event_type == "floating_domain":
|
||||||
domain = _AGENT_DOMAIN.get(agent_name, "tasks")
|
if isinstance(data, dict):
|
||||||
yield WsFloatingDomain(
|
yield WsFloatingDomain(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
domain=domain, # type: ignore[arg-type]
|
domain=data,
|
||||||
)
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event_type != "token":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not started:
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
domain_sent = True
|
started = True
|
||||||
|
|
||||||
if token:
|
text = str(data or "")
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=token)
|
if text:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=text)
|
||||||
|
|
||||||
|
if not started:
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
yield WsStreamEnd(request_id=self.request_id)
|
yield WsStreamEnd(request_id=self.request_id)
|
||||||
|
|
||||||
|
|
||||||
# ── helpers ───────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
|
|
||||||
"""Attempt to parse the first complete JSON object from *text*.
|
|
||||||
|
|
||||||
Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the
|
|
||||||
object is incomplete, and raises ``ValueError`` when text is not JSON.
|
|
||||||
"""
|
|
||||||
decoder = json.JSONDecoder()
|
|
||||||
try:
|
|
||||||
obj, end_idx = decoder.raw_decode(text)
|
|
||||||
if not isinstance(obj, dict):
|
|
||||||
raise ValueError("Expected JSON object")
|
|
||||||
return obj, end_idx
|
|
||||||
except json.JSONDecodeError as exc:
|
|
||||||
# Incomplete JSON — need more tokens
|
|
||||||
if "Unterminated" in str(exc) or exc.pos == len(text):
|
|
||||||
return None, 0
|
|
||||||
raise ValueError(str(exc)) from exc
|
|
||||||
|
|||||||
104
app/core/preprocessors/__init__.py
Normal file
104
app/core/preprocessors/__init__.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""Preprocessor registry: detect content type and dispatch to handlers.
|
||||||
|
|
||||||
|
Public API
|
||||||
|
----------
|
||||||
|
detect_content_type(filename, raw_content) -> str
|
||||||
|
Heuristic detection based on file extension and content patterns.
|
||||||
|
|
||||||
|
preprocess(content_type, raw_content) -> PreprocessResult
|
||||||
|
Dispatch to the appropriate handler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from app.core.preprocessors.base import PreprocessResult
|
||||||
|
|
||||||
|
# ── Heuristics ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Patterns that strongly suggest an email HTML file
|
||||||
|
_EMAIL_SIGNALS = re.compile(
|
||||||
|
r"(Subject:|From:|To:|Date:|Sent:|MIME-Version:|Content-Type:\s*text/html)",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Patterns that suggest a generic HTML page (not an email)
|
||||||
|
_GENERIC_HTML_SIGNALS = re.compile(
|
||||||
|
r"<(nav|main|header|footer|article|section)\b",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_content_type(filename: str, raw_content: str) -> str:
|
||||||
|
"""Return a content-type string for the given file.
|
||||||
|
|
||||||
|
Supported types: ``"email_html"``, ``"generic_html"``,
|
||||||
|
``"plain_text"``, ``"unknown"``.
|
||||||
|
"""
|
||||||
|
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
|
||||||
|
|
||||||
|
if ext == "txt":
|
||||||
|
return "plain_text"
|
||||||
|
|
||||||
|
if ext in ("html", "htm", "eml", "mhtml", "mht"):
|
||||||
|
# Prefer email detection over generic HTML
|
||||||
|
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
||||||
|
return "email_html"
|
||||||
|
if _GENERIC_HTML_SIGNALS.search(raw_content[:4096]) or "<html" in raw_content[:200].lower():
|
||||||
|
return "generic_html"
|
||||||
|
# .html without clear signals — check for any email header
|
||||||
|
if re.search(r"^(From|To|Subject|Date):", raw_content[:2048], re.MULTILINE | re.IGNORECASE):
|
||||||
|
return "email_html"
|
||||||
|
return "generic_html"
|
||||||
|
|
||||||
|
# Plain text files with email headers
|
||||||
|
if ext in ("", "txt") or not ext:
|
||||||
|
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
||||||
|
return "email_html"
|
||||||
|
|
||||||
|
# Detect binary content
|
||||||
|
try:
|
||||||
|
raw_content.encode("utf-8")
|
||||||
|
except (UnicodeEncodeError, AttributeError):
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
# Non-text bytes heuristic: high ratio of non-printable chars
|
||||||
|
sample = raw_content[:512]
|
||||||
|
non_printable = sum(1 for c in sample if ord(c) < 32 and c not in "\r\n\t")
|
||||||
|
if len(sample) > 0 and non_printable / len(sample) > 0.1:
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Generic fallback handler ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def _preprocess_generic(raw_content: str, content_type: str) -> PreprocessResult:
|
||||||
|
"""Strip HTML tags if present, return text as-is."""
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
text = BeautifulSoup(raw_content, "html.parser").get_text(separator="\n")
|
||||||
|
except ImportError:
|
||||||
|
# No BeautifulSoup — strip tags with a simple regex
|
||||||
|
text = re.sub(r"<[^>]+>", "", raw_content)
|
||||||
|
|
||||||
|
text = re.sub(r"\n{3,}", "\n\n", text).strip()
|
||||||
|
return PreprocessResult(content_type=content_type, clean_text=text, metadata={})
|
||||||
|
|
||||||
|
|
||||||
|
# ── Dispatch ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def preprocess(content_type: str, raw_content: str) -> PreprocessResult:
|
||||||
|
"""Dispatch *raw_content* to the handler registered for *content_type*.
|
||||||
|
|
||||||
|
Falls back to the generic handler for unknown types.
|
||||||
|
"""
|
||||||
|
if content_type == "email_html":
|
||||||
|
from app.core.preprocessors.email_html import preprocess_email_html
|
||||||
|
return preprocess_email_html(raw_content)
|
||||||
|
|
||||||
|
return _preprocess_generic(raw_content, content_type)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["detect_content_type", "preprocess", "PreprocessResult"]
|
||||||
25
app/core/preprocessors/base.py
Normal file
25
app/core/preprocessors/base.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""Base types for the preprocessor system."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PreprocessResult:
|
||||||
|
"""Output of a preprocessor handler.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
content_type:
|
||||||
|
The detected content type (e.g. ``"email_html"``, ``"plain_text"``).
|
||||||
|
clean_text:
|
||||||
|
Human-readable text stripped of markup/binary noise.
|
||||||
|
metadata:
|
||||||
|
Dict of extracted metadata (keys vary by handler).
|
||||||
|
Common keys: ``subject``, ``from``, ``to``, ``date``, ``filename``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
content_type: str
|
||||||
|
clean_text: str
|
||||||
|
metadata: dict = field(default_factory=dict)
|
||||||
111
app/core/preprocessors/email_html.py
Normal file
111
app/core/preprocessors/email_html.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""Preprocessor for email HTML files.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- HTML stripping via BeautifulSoup
|
||||||
|
- Metadata extraction (Subject, From, To, Date)
|
||||||
|
- Thread splitting — isolates the latest reply
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from app.core.preprocessors.base import PreprocessResult
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ── Thread split markers ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Matches patterns like:
|
||||||
|
# "On Mon, Apr 7, 2026 at 10:00 AM, Alice <alice@co.com> wrote:"
|
||||||
|
# "-----Original Message-----"
|
||||||
|
# "> " (plain-text quote prefix)
|
||||||
|
_THREAD_PATTERNS = [
|
||||||
|
re.compile(r"^On\s+.+wrote\s*:", re.IGNORECASE | re.MULTILINE),
|
||||||
|
re.compile(r"^-{3,}\s*(original message|forwarded message)\s*-{3,}", re.IGNORECASE | re.MULTILINE),
|
||||||
|
re.compile(r"^>{1,}\s+\S", re.MULTILINE),
|
||||||
|
re.compile(r"^From:\s+.+\nSent:\s+", re.IGNORECASE | re.MULTILINE),
|
||||||
|
]
|
||||||
|
|
||||||
|
# ── Metadata patterns (applied on raw HTML / plain fallback) ──────────
|
||||||
|
|
||||||
|
_META_PATTERNS: dict[str, list[re.Pattern]] = {
|
||||||
|
"subject": [
|
||||||
|
re.compile(r"<title>(.+?)</title>", re.IGNORECASE | re.DOTALL),
|
||||||
|
re.compile(r"Subject:\s*(.+)", re.IGNORECASE),
|
||||||
|
],
|
||||||
|
"from": [
|
||||||
|
re.compile(r'<meta[^>]+name=["\']?from["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||||
|
re.compile(r"From:\s*(.+)", re.IGNORECASE),
|
||||||
|
],
|
||||||
|
"to": [
|
||||||
|
re.compile(r'<meta[^>]+name=["\']?to["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||||
|
re.compile(r"To:\s*(.+)", re.IGNORECASE),
|
||||||
|
],
|
||||||
|
"date": [
|
||||||
|
re.compile(r'<meta[^>]+name=["\']?date["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||||
|
re.compile(r"Date:\s*(.+)", re.IGNORECASE),
|
||||||
|
re.compile(r"Sent:\s*(.+)", re.IGNORECASE),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_metadata(raw_html: str, text: str) -> dict:
|
||||||
|
"""Extract Subject/From/To/Date from raw HTML or plain text."""
|
||||||
|
metadata: dict[str, str] = {}
|
||||||
|
for field, patterns in _META_PATTERNS.items():
|
||||||
|
for pat in patterns:
|
||||||
|
m = pat.search(raw_html) or pat.search(text)
|
||||||
|
if m:
|
||||||
|
metadata[field] = m.group(1).strip()
|
||||||
|
break
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _split_thread(text: str) -> str:
|
||||||
|
"""Return only the latest message in a threaded email."""
|
||||||
|
earliest_pos: int | None = None
|
||||||
|
for pat in _THREAD_PATTERNS:
|
||||||
|
m = pat.search(text)
|
||||||
|
if m and (earliest_pos is None or m.start() < earliest_pos):
|
||||||
|
earliest_pos = m.start()
|
||||||
|
|
||||||
|
if earliest_pos is not None and earliest_pos > 0:
|
||||||
|
return text[:earliest_pos].strip()
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_email_html(raw_content: str) -> PreprocessResult:
|
||||||
|
"""Strip HTML, extract metadata, split thread from an email HTML file."""
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup # lazy import — optional dep
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"beautifulsoup4 is required for email_html preprocessing. "
|
||||||
|
"Install it with: pip install beautifulsoup4"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
# Parse with lxml if available, fall back to html.parser
|
||||||
|
try:
|
||||||
|
soup = BeautifulSoup(raw_content, "lxml")
|
||||||
|
except Exception:
|
||||||
|
soup = BeautifulSoup(raw_content, "html.parser")
|
||||||
|
|
||||||
|
# Remove noise tags
|
||||||
|
for tag in soup(["style", "script", "head", "noscript"]):
|
||||||
|
tag.decompose()
|
||||||
|
|
||||||
|
clean_text = soup.get_text(separator="\n")
|
||||||
|
# Collapse excessive blank lines
|
||||||
|
clean_text = re.sub(r"\n{3,}", "\n\n", clean_text).strip()
|
||||||
|
|
||||||
|
metadata = _extract_metadata(raw_content, clean_text)
|
||||||
|
latest_message = _split_thread(clean_text)
|
||||||
|
|
||||||
|
return PreprocessResult(
|
||||||
|
content_type="email_html",
|
||||||
|
clean_text=latest_message,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
@@ -25,7 +25,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|||||||
96
app/main.py
96
app/main.py
@@ -4,6 +4,10 @@ import logging
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||||
|
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
@@ -11,19 +15,88 @@ logging.basicConfig(
|
|||||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||||
|
|
||||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
|
||||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
async def _memory_audit_cron_tick() -> None:
|
||||||
from app.config.settings import settings
|
"""Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
|
||||||
|
import logging # noqa: PLC0415
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
_log.info("memory audit cron tick: starting")
|
||||||
|
try:
|
||||||
|
from app.db import async_session # noqa: PLC0415
|
||||||
|
from app.core.memory_maintenance import audit_memory # noqa: PLC0415
|
||||||
|
from app.models import User # noqa: PLC0415
|
||||||
|
from sqlalchemy import select # noqa: PLC0415
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(select(User.id))
|
||||||
|
user_ids: list[str] = list(result.scalars().all())
|
||||||
|
|
||||||
|
for uid in user_ids:
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
await audit_memory(db, uid)
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("memory audit cron tick: audit_memory failed user=%s: %s", uid, exc)
|
||||||
|
|
||||||
|
_log.info("memory audit cron tick: done users=%d", len(user_ids))
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("memory audit cron tick: failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _memory_cron_tick() -> None:
|
||||||
|
"""Hourly cron: drain Free-tier extraction queue + mine proactive patterns for Power+ users."""
|
||||||
|
import logging # noqa: PLC0415
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
_log.info("memory cron tick: starting")
|
||||||
|
try:
|
||||||
|
from app.db import async_session # noqa: PLC0415
|
||||||
|
from app.core.memory_maintenance import drain_extraction_queue, mine_proactive_patterns # noqa: PLC0415
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.models import User # noqa: PLC0415
|
||||||
|
from sqlalchemy import select # noqa: PLC0415
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
await drain_extraction_queue(db)
|
||||||
|
|
||||||
|
# mine proactive patterns for every Power+ user
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(select(User.id))
|
||||||
|
user_ids: list[str] = list(result.scalars().all())
|
||||||
|
|
||||||
|
for uid in user_ids:
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
tier = await tier_manager.get_tier(uid, db)
|
||||||
|
if tier_manager.check_feature(tier, "proactive_mining"):
|
||||||
|
await mine_proactive_patterns(db, uid)
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("memory cron tick: mine_proactive_patterns failed user=%s: %s", uid, exc)
|
||||||
|
|
||||||
|
_log.info("memory cron tick: done users=%d", len(user_ids))
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("memory cron tick: failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup: initialise DB connection pool and agent registry
|
# Startup: ensure agent tool modules are loaded.
|
||||||
from app.core.agent_registry import registry # noqa: F401 — triggers module load
|
import app.agents # noqa: F401
|
||||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
|
||||||
|
scheduler = None
|
||||||
|
if settings.SCHEDULER_ENABLED:
|
||||||
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler # noqa: PLC0415
|
||||||
|
|
||||||
|
scheduler = AsyncIOScheduler()
|
||||||
|
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
|
||||||
|
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
|
||||||
|
scheduler.start()
|
||||||
|
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
if scheduler is not None:
|
||||||
|
scheduler.shutdown(wait=False)
|
||||||
|
|
||||||
# Shutdown: dispose SQLAlchemy connection pool
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
from app.db import engine
|
from app.db import engine
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
@@ -31,7 +104,7 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Adiuva Cloud API",
|
title="AdiuvAI Cloud API",
|
||||||
version="0.1.0",
|
version="0.1.0",
|
||||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||||
redoc_url=None,
|
redoc_url=None,
|
||||||
@@ -51,19 +124,14 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors
|
from app.api.routes import agents, auth, billing, chat, device_ws, memory
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
app.include_router(plans.router, prefix="/api/v1")
|
|
||||||
app.include_router(storage.router, prefix="/api/v1")
|
|
||||||
app.include_router(vectors.router, prefix="/api/v1")
|
|
||||||
app.include_router(backup.router, prefix="/api/v1")
|
|
||||||
app.include_router(plugins.router, prefix="/api/v1")
|
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
app.include_router(agent_setup.router, prefix="/api/v1")
|
|
||||||
app.include_router(device_ws.router, prefix="/api/v1")
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
app.include_router(memory.router, prefix="/api/v1")
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
async def health() -> dict:
|
async def health() -> dict:
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
"""Plugin marketplace package.
|
|
||||||
|
|
||||||
Three service classes introduced in Step 10:
|
|
||||||
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
|
|
||||||
- ``ReviewQueue`` — approval workflow + security checklist
|
|
||||||
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
|
|
||||||
"""
|
|
||||||
@@ -1,212 +0,0 @@
|
|||||||
"""Plugin catalog registry backed by PostgreSQL.
|
|
||||||
|
|
||||||
Maintains the authoritative list of plugins, their review status, and
|
|
||||||
aggregate install counts. All data is persisted in the ``plugins`` table.
|
|
||||||
|
|
||||||
Module-level singleton::
|
|
||||||
|
|
||||||
from app.marketplace.plugin_registry import registry
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from sqlalchemy import select, func
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.models import Plugin
|
|
||||||
from app.schemas import PluginListResponse, PluginManifest
|
|
||||||
|
|
||||||
_PAGE_SIZE = 20
|
|
||||||
|
|
||||||
|
|
||||||
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
|
|
||||||
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
|
|
||||||
try:
|
|
||||||
permissions = json.loads(p.permissions) if p.permissions else []
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
permissions = []
|
|
||||||
return PluginManifest(
|
|
||||||
id=p.id,
|
|
||||||
name=p.name,
|
|
||||||
description=p.description,
|
|
||||||
version=p.version,
|
|
||||||
author=p.author_name,
|
|
||||||
permissions=permissions,
|
|
||||||
category=p.category,
|
|
||||||
price_cents=p.price_cents,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PluginRegistry:
|
|
||||||
"""PostgreSQL-backed plugin catalog.
|
|
||||||
|
|
||||||
All methods accept an ``AsyncSession`` parameter so the calling route
|
|
||||||
controls the session lifecycle.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ── Queries ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def list_plugins(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
category: str | None = None,
|
|
||||||
query: str | None = None,
|
|
||||||
page: int = 1,
|
|
||||||
sort: Literal["rating", "installs", "newest"] = "newest",
|
|
||||||
) -> PluginListResponse:
|
|
||||||
"""Return a page of approved plugins, optionally filtered and sorted."""
|
|
||||||
base = select(Plugin).where(Plugin.status == "approved")
|
|
||||||
|
|
||||||
if category:
|
|
||||||
base = base.where(Plugin.category == category)
|
|
||||||
if query:
|
|
||||||
pattern = f"%{query}%"
|
|
||||||
base = base.where(
|
|
||||||
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Count
|
|
||||||
count_q = select(func.count()).select_from(base.subquery())
|
|
||||||
total = (await db.execute(count_q)).scalar_one()
|
|
||||||
|
|
||||||
# Sort
|
|
||||||
if sort == "installs":
|
|
||||||
base = base.order_by(Plugin.install_count.desc())
|
|
||||||
elif sort == "rating":
|
|
||||||
base = base.order_by(Plugin.avg_rating.desc())
|
|
||||||
else: # newest
|
|
||||||
base = base.order_by(Plugin.created_at.desc())
|
|
||||||
|
|
||||||
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
|
|
||||||
rows = (await db.execute(base)).scalars().all()
|
|
||||||
|
|
||||||
return PluginListResponse(
|
|
||||||
plugins=[_plugin_to_manifest(r) for r in rows],
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
|
|
||||||
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
|
||||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
p = result.scalar_one_or_none()
|
|
||||||
if p is None:
|
|
||||||
return None
|
|
||||||
return {
|
|
||||||
"manifest": _plugin_to_manifest(p),
|
|
||||||
"status": p.status,
|
|
||||||
"install_count": p.install_count,
|
|
||||||
"avg_rating": p.avg_rating,
|
|
||||||
}
|
|
||||||
|
|
||||||
# ── Mutations ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def submit_plugin(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
manifest: PluginManifest,
|
|
||||||
package_s3_key: str,
|
|
||||||
) -> str:
|
|
||||||
"""Add *manifest* to the catalog with ``status='pending_review'``.
|
|
||||||
|
|
||||||
Returns the plugin_id. If a plugin with the same id already exists
|
|
||||||
it is overwritten (re-submission after rejection).
|
|
||||||
"""
|
|
||||||
plugin_id = manifest.id
|
|
||||||
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
row = existing.scalar_one_or_none()
|
|
||||||
|
|
||||||
if row is not None:
|
|
||||||
row.name = manifest.name
|
|
||||||
row.description = manifest.description
|
|
||||||
row.version = manifest.version
|
|
||||||
row.author_name = manifest.author
|
|
||||||
row.category = manifest.category
|
|
||||||
row.price_cents = manifest.price_cents
|
|
||||||
row.permissions = json.dumps(manifest.permissions)
|
|
||||||
row.status = "pending_review"
|
|
||||||
row.s3_package_key = package_s3_key
|
|
||||||
row.rejection_reason = None
|
|
||||||
else:
|
|
||||||
row = Plugin(
|
|
||||||
id=plugin_id,
|
|
||||||
name=manifest.name,
|
|
||||||
description=manifest.description,
|
|
||||||
version=manifest.version,
|
|
||||||
author_name=manifest.author,
|
|
||||||
category=manifest.category,
|
|
||||||
price_cents=manifest.price_cents,
|
|
||||||
permissions=json.dumps(manifest.permissions),
|
|
||||||
status="pending_review",
|
|
||||||
s3_package_key=package_s3_key,
|
|
||||||
install_count=0,
|
|
||||||
avg_rating=0.0,
|
|
||||||
)
|
|
||||||
db.add(row)
|
|
||||||
await db.commit()
|
|
||||||
return plugin_id
|
|
||||||
|
|
||||||
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
|
|
||||||
"""Set *plugin_id* status to ``'approved'``.
|
|
||||||
|
|
||||||
Raises ``KeyError`` if the plugin is not found.
|
|
||||||
"""
|
|
||||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
|
||||||
row.status = "approved"
|
|
||||||
row.rejection_reason = None
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
|
|
||||||
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
|
||||||
|
|
||||||
Raises ``KeyError`` if the plugin is not found.
|
|
||||||
"""
|
|
||||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
|
||||||
row.status = "rejected"
|
|
||||||
row.rejection_reason = reason
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
|
|
||||||
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
|
||||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is not None:
|
|
||||||
row.install_count = row.install_count + 1
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
|
|
||||||
"""Decrement the install count for *plugin_id*, floored at 0."""
|
|
||||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is not None:
|
|
||||||
row.install_count = max(0, row.install_count - 1)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
|
||||||
|
|
||||||
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
|
|
||||||
"""Return all entries with status='pending_review'."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(Plugin).where(Plugin.status == "pending_review")
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"manifest": _plugin_to_manifest(r),
|
|
||||||
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
|
|
||||||
}
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
registry = PluginRegistry()
|
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
"""Plugin review workflow backed by PostgreSQL.
|
|
||||||
|
|
||||||
Manages the approval queue for newly submitted plugins and enforces a
|
|
||||||
security checklist before any plugin is made visible in the marketplace.
|
|
||||||
|
|
||||||
Module-level singleton::
|
|
||||||
|
|
||||||
from app.marketplace.plugin_review import review_queue
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.marketplace.plugin_registry import registry
|
|
||||||
from app.models import PluginReview as PluginReviewModel
|
|
||||||
from app.schemas import PluginManifest
|
|
||||||
|
|
||||||
# ── Security policy ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
|
|
||||||
{
|
|
||||||
"read:tasks",
|
|
||||||
"write:tasks",
|
|
||||||
"read:projects",
|
|
||||||
"write:projects",
|
|
||||||
"read:notes",
|
|
||||||
"write:notes",
|
|
||||||
"read:timelines",
|
|
||||||
"write:timelines",
|
|
||||||
"read:calendar",
|
|
||||||
"write:calendar",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
|
|
||||||
|
|
||||||
|
|
||||||
def validate_manifest(manifest: PluginManifest) -> None:
|
|
||||||
"""Enforce the plugin security checklist.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
``ValueError`` on the first violation found. Callers should catch
|
|
||||||
this and return HTTP 422 / reject the submission.
|
|
||||||
|
|
||||||
Checks:
|
|
||||||
1. Plugin id matches ``^[a-z0-9-]+$``
|
|
||||||
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
|
|
||||||
3. No manifest field contains raw binary data
|
|
||||||
"""
|
|
||||||
if not _PLUGIN_ID_RE.match(manifest.id):
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid plugin id format: '{manifest.id}'. "
|
|
||||||
"Only lowercase letters, digits, and hyphens are allowed."
|
|
||||||
)
|
|
||||||
|
|
||||||
for perm in manifest.permissions:
|
|
||||||
if perm not in ALLOWED_PERMISSIONS:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown permission: '{perm}'. "
|
|
||||||
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
for field_name, value in manifest.model_dump().items():
|
|
||||||
if isinstance(value, (bytes, bytearray)):
|
|
||||||
raise ValueError(
|
|
||||||
f"Binary content is not allowed in manifest field '{field_name}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReviewQueue:
|
|
||||||
"""Approval queue for pending plugin submissions.
|
|
||||||
|
|
||||||
Delegates status changes to the shared ``PluginRegistry`` singleton.
|
|
||||||
Review records are persisted in the ``plugin_reviews`` table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
|
|
||||||
"""Return all plugins currently awaiting review.
|
|
||||||
|
|
||||||
Each item is ``{plugin_id, manifest, submitted_at}``.
|
|
||||||
"""
|
|
||||||
entries = await registry.get_pending_entries(db)
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"plugin_id": e["manifest"].id,
|
|
||||||
"manifest": e["manifest"],
|
|
||||||
"submitted_at": e["submitted_at"],
|
|
||||||
}
|
|
||||||
for e in entries
|
|
||||||
]
|
|
||||||
|
|
||||||
async def submit_review(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
plugin_id: str,
|
|
||||||
reviewer_id: str,
|
|
||||||
decision: Literal["approved", "rejected"],
|
|
||||||
notes: str = "",
|
|
||||||
) -> None:
|
|
||||||
"""Record a review decision and update the plugin's status.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
``KeyError`` if *plugin_id* is not found in the registry.
|
|
||||||
"""
|
|
||||||
if decision == "approved":
|
|
||||||
await registry.approve_plugin(db, plugin_id)
|
|
||||||
else:
|
|
||||||
await registry.reject_plugin(db, plugin_id, reason=notes)
|
|
||||||
|
|
||||||
review = PluginReviewModel(
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
reviewer_id=reviewer_id,
|
|
||||||
decision=decision,
|
|
||||||
notes=notes,
|
|
||||||
)
|
|
||||||
db.add(review)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
review_queue = ReviewQueue()
|
|
||||||
@@ -1,233 +0,0 @@
|
|||||||
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
|
|
||||||
|
|
||||||
Records every plugin installation as a revenue event and facilitates
|
|
||||||
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
|
|
||||||
in the ``revenue_events`` table.
|
|
||||||
|
|
||||||
Module-level singleton::
|
|
||||||
|
|
||||||
from app.marketplace.revenue_share import revenue_share
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import stripe as stripe_lib
|
|
||||||
from sqlalchemy import extract, func, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.marketplace.plugin_registry import registry
|
|
||||||
from app.models import Plugin, RevenueEvent
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ── Revenue split constants ───────────────────────────────────────────
|
|
||||||
|
|
||||||
DEVELOPER_SHARE: float = 0.70
|
|
||||||
PLATFORM_SHARE: float = 0.30
|
|
||||||
|
|
||||||
|
|
||||||
class RevenueShare:
|
|
||||||
"""Records installation revenue events and coordinates developer payouts.
|
|
||||||
|
|
||||||
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
|
|
||||||
is not configured, consistent with the rest of the billing layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _stripe_configured() -> bool:
|
|
||||||
return bool(settings.STRIPE_SECRET_KEY)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _stripe() -> Any:
|
|
||||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
|
||||||
return stripe_lib
|
|
||||||
|
|
||||||
# ── Core operations ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def record_install(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
plugin_id: str,
|
|
||||||
user_id: str,
|
|
||||||
amount_cents: int,
|
|
||||||
) -> None:
|
|
||||||
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
|
|
||||||
|
|
||||||
For free plugins (``amount_cents == 0``) no payment is initiated but
|
|
||||||
the event is still recorded for analytics.
|
|
||||||
|
|
||||||
For paid plugins the developer receives 70 % via a Stripe Connect
|
|
||||||
destination charge. If Stripe is not configured or the charge fails
|
|
||||||
the installation still succeeds (the event is recorded and the install
|
|
||||||
count is incremented) — a warning is logged for monitoring.
|
|
||||||
"""
|
|
||||||
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
|
|
||||||
stripe_transfer_id: str | None = None
|
|
||||||
|
|
||||||
if amount_cents > 0 and self._stripe_configured():
|
|
||||||
# Look up the plugin's author Stripe account from the DB
|
|
||||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
plugin_row = result.scalar_one_or_none()
|
|
||||||
developer_stripe_account: str | None = None
|
|
||||||
if plugin_row and plugin_row.author_id:
|
|
||||||
# Future: look up user.stripe_connect_account_id
|
|
||||||
developer_stripe_account = None # no real account yet
|
|
||||||
|
|
||||||
if developer_stripe_account:
|
|
||||||
try:
|
|
||||||
s = self._stripe()
|
|
||||||
transfer = s.Transfer.create(
|
|
||||||
amount=developer_share_cents,
|
|
||||||
currency="eur",
|
|
||||||
destination=developer_stripe_account,
|
|
||||||
description=f"Revenue share for plugin {plugin_id}",
|
|
||||||
metadata={"plugin_id": plugin_id, "user_id": user_id},
|
|
||||||
)
|
|
||||||
stripe_transfer_id = transfer["id"]
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"Stripe Connect transfer failed for plugin %s: %s",
|
|
||||||
plugin_id,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
"No Stripe account on file for plugin %s developer; "
|
|
||||||
"skipping transfer.",
|
|
||||||
plugin_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
event = RevenueEvent(
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
user_id=user_id,
|
|
||||||
amount_cents=amount_cents,
|
|
||||||
developer_share_cents=developer_share_cents,
|
|
||||||
stripe_transfer_id=stripe_transfer_id,
|
|
||||||
)
|
|
||||||
db.add(event)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
await registry.record_install(db, plugin_id)
|
|
||||||
|
|
||||||
async def get_earnings(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
developer_id: str,
|
|
||||||
period: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Return aggregated earnings for *developer_id*.
|
|
||||||
|
|
||||||
``period`` is an optional ``YYYY-MM`` string to restrict the window.
|
|
||||||
|
|
||||||
Returns::
|
|
||||||
|
|
||||||
{
|
|
||||||
"developer_id": str,
|
|
||||||
"period": str | None,
|
|
||||||
"total_installs": int,
|
|
||||||
"total_revenue_cents": int,
|
|
||||||
"developer_share_cents": int,
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
# Find plugin ids belonging to this developer (by author_name match)
|
|
||||||
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
|
|
||||||
plugin_result = await db.execute(plugin_q)
|
|
||||||
developer_plugin_ids = [row[0] for row in plugin_result.all()]
|
|
||||||
|
|
||||||
if not developer_plugin_ids:
|
|
||||||
return {
|
|
||||||
"developer_id": developer_id,
|
|
||||||
"period": period,
|
|
||||||
"total_installs": 0,
|
|
||||||
"total_revenue_cents": 0,
|
|
||||||
"developer_share_cents": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
query = select(
|
|
||||||
func.count().label("total_installs"),
|
|
||||||
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
|
|
||||||
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
|
|
||||||
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
|
|
||||||
|
|
||||||
if period:
|
|
||||||
# Filter by YYYY-MM: extract year and month from created_at
|
|
||||||
try:
|
|
||||||
year, month = period.split("-")
|
|
||||||
query = query.where(
|
|
||||||
extract("year", RevenueEvent.created_at) == int(year),
|
|
||||||
extract("month", RevenueEvent.created_at) == int(month),
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
pass # invalid period format — return all
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
row = result.one()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"developer_id": developer_id,
|
|
||||||
"period": period,
|
|
||||||
"total_installs": row.total_installs,
|
|
||||||
"total_revenue_cents": row.total_revenue,
|
|
||||||
"developer_share_cents": row.dev_share,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
|
|
||||||
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
|
||||||
|
|
||||||
Marks processed events with ``paid_at`` timestamp.
|
|
||||||
Stubs gracefully when Stripe is not configured.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
year, month = period.split("-")
|
|
||||||
year_int, month_int = int(year), int(month)
|
|
||||||
except ValueError:
|
|
||||||
logger.warning("Invalid period format: %s", period)
|
|
||||||
return
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(RevenueEvent).where(
|
|
||||||
RevenueEvent.plugin_id == plugin_id,
|
|
||||||
RevenueEvent.paid_at.is_(None),
|
|
||||||
extract("year", RevenueEvent.created_at) == year_int,
|
|
||||||
extract("month", RevenueEvent.created_at) == month_int,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
unpaid = list(result.scalars().all())
|
|
||||||
|
|
||||||
total_dev_share = sum(e.developer_share_cents for e in unpaid)
|
|
||||||
if total_dev_share <= 0 or not unpaid:
|
|
||||||
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
|
||||||
return
|
|
||||||
|
|
||||||
if self._stripe_configured():
|
|
||||||
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
plugin_row = plugin_result.scalar_one_or_none()
|
|
||||||
developer_stripe_account: str | None = None # Future: fetch from DB
|
|
||||||
if plugin_row and developer_stripe_account:
|
|
||||||
try:
|
|
||||||
s = self._stripe()
|
|
||||||
s.Transfer.create(
|
|
||||||
amount=total_dev_share,
|
|
||||||
currency="eur",
|
|
||||||
destination=developer_stripe_account,
|
|
||||||
description=f"Payout for plugin {plugin_id} period {period}",
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
|
||||||
return
|
|
||||||
|
|
||||||
paid_ts = datetime.now(timezone.utc)
|
|
||||||
for event in unpaid:
|
|
||||||
event.paid_at = paid_ts
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
revenue_share = RevenueShare()
|
|
||||||
281
app/models.py
281
app/models.py
@@ -1,23 +1,20 @@
|
|||||||
"""SQLAlchemy ORM models for all persistent tables.
|
"""SQLAlchemy ORM models for all persistent tables.
|
||||||
|
|
||||||
Only auth, billing, storage metadata, and marketplace data live here.
|
Only auth, billing, agent config, and memory data live here.
|
||||||
User content (notes, tasks, etc.) is NEVER persisted server-side —
|
User content (notes, tasks, etc.) lives exclusively on the client.
|
||||||
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
|
|
||||||
|
|
||||||
Table inventory:
|
Table inventory:
|
||||||
users — account credentials + tier
|
users — account credentials + tier
|
||||||
refresh_tokens — hashed refresh token store
|
refresh_tokens — hashed refresh token store
|
||||||
subscriptions — Stripe subscription records
|
subscriptions — Stripe subscription records
|
||||||
storage_records — S3 blob metadata (no plaintext)
|
local_agent_configs — per-device batch agent configs
|
||||||
backup_metadata — encrypted backup manifests
|
cloud_agent_configs — OAuth-backed cloud agent configs
|
||||||
plugins — marketplace plugin catalog
|
agent_run_logs — execution history for all agents
|
||||||
plugin_installations — per-user install records
|
|
||||||
plugin_reviews — admin review decisions
|
|
||||||
revenue_events — Stripe Connect 70/30 split ledger
|
|
||||||
memory_core — per-user persistent key/value preferences (encrypted)
|
memory_core — per-user persistent key/value preferences (encrypted)
|
||||||
memory_associative — per-user semantic memory with embeddings (encrypted)
|
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||||
memory_episodic — per-user session summaries (encrypted)
|
memory_episodic — per-user session summaries (encrypted)
|
||||||
memory_proactive — per-user behavioral patterns (encrypted)
|
memory_proactive — per-user behavioral patterns (encrypted)
|
||||||
|
memory_relations — per-user entity/relation graph (Mem0g-light, Phase 3)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -25,8 +22,8 @@ from __future__ import annotations
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from pgvector.sqlalchemy import Vector
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
BigInteger,
|
|
||||||
Boolean,
|
Boolean,
|
||||||
DateTime,
|
DateTime,
|
||||||
Enum,
|
Enum,
|
||||||
@@ -34,9 +31,9 @@ from sqlalchemy import (
|
|||||||
ForeignKey,
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
JSON,
|
JSON,
|
||||||
|
LargeBinary,
|
||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
UniqueConstraint,
|
|
||||||
Uuid,
|
Uuid,
|
||||||
func,
|
func,
|
||||||
)
|
)
|
||||||
@@ -58,8 +55,6 @@ def _now() -> datetime:
|
|||||||
# ── Enum types ────────────────────────────────────────────────────────────
|
# ── Enum types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||||
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
|
||||||
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
|
||||||
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
||||||
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
||||||
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
||||||
@@ -77,7 +72,8 @@ class User(Base):
|
|||||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
password_hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
avatar_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||||
@@ -86,6 +82,9 @@ class User(Base):
|
|||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
)
|
)
|
||||||
|
onboarding_completed_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=True, default=None
|
||||||
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
)
|
)
|
||||||
@@ -96,6 +95,9 @@ class User(Base):
|
|||||||
subscription: Mapped[Subscription | None] = relationship(
|
subscription: Mapped[Subscription | None] = relationship(
|
||||||
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||||
|
back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RefreshToken(Base):
|
class RefreshToken(Base):
|
||||||
@@ -116,6 +118,25 @@ class RefreshToken(Base):
|
|||||||
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccount(Base):
|
||||||
|
__tablename__ = "oauth_accounts"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
provider_user_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
provider_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="oauth_accounts")
|
||||||
|
|
||||||
|
|
||||||
class Subscription(Base):
|
class Subscription(Base):
|
||||||
__tablename__ = "subscriptions"
|
__tablename__ = "subscriptions"
|
||||||
|
|
||||||
@@ -137,151 +158,6 @@ class Subscription(Base):
|
|||||||
user: Mapped[User] = relationship(back_populates="subscription")
|
user: Mapped[User] = relationship(back_populates="subscription")
|
||||||
|
|
||||||
|
|
||||||
class StorageRecord(Base):
|
|
||||||
__tablename__ = "storage_records"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
|
||||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
|
||||||
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
|
||||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BackupMetadata(Base):
|
|
||||||
__tablename__ = "backup_metadata"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
|
||||||
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
|
||||||
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
|
||||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Plugin(Base):
|
|
||||||
__tablename__ = "plugins"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
|
||||||
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
|
||||||
# nullable until developer account system is built
|
|
||||||
author_id: Mapped[str | None] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
|
||||||
)
|
|
||||||
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
|
||||||
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
|
||||||
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
|
|
||||||
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
|
||||||
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
|
||||||
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
|
||||||
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
submitted_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
installations: Mapped[list[PluginInstallation]] = relationship(
|
|
||||||
back_populates="plugin", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
reviews: Mapped[list[PluginReview]] = relationship(
|
|
||||||
back_populates="plugin", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
|
||||||
back_populates="plugin", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PluginInstallation(Base):
|
|
||||||
__tablename__ = "plugin_installations"
|
|
||||||
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
plugin_id: Mapped[str] = mapped_column(
|
|
||||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
installed_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
|
||||||
|
|
||||||
|
|
||||||
class PluginReview(Base):
|
|
||||||
__tablename__ = "plugin_reviews"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
plugin_id: Mapped[str] = mapped_column(
|
|
||||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
reviewer_id: Mapped[str | None] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
|
||||||
)
|
|
||||||
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
|
||||||
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
reviewed_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
|
||||||
|
|
||||||
|
|
||||||
class RevenueEvent(Base):
|
|
||||||
__tablename__ = "revenue_events"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
plugin_id: Mapped[str] = mapped_column(
|
|
||||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
|
||||||
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfig(Base):
|
class LocalAgentConfig(Base):
|
||||||
__tablename__ = "local_agent_configs"
|
__tablename__ = "local_agent_configs"
|
||||||
|
|
||||||
@@ -296,6 +172,7 @@ class LocalAgentConfig(Base):
|
|||||||
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
agent_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||||
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
@@ -425,8 +302,8 @@ class MemoryAssociative(Base):
|
|||||||
nullable=False, index=True,
|
nullable=False, index=True,
|
||||||
)
|
)
|
||||||
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
|
# vector(1536) via pgvector; SQLite tests use NULL embeddings so no dialect issue.
|
||||||
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
embedding: Mapped[list | None] = mapped_column(Vector(1536), nullable=True)
|
||||||
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
@@ -474,3 +351,85 @@ class MemoryProactive(Base):
|
|||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionQueue(Base):
|
||||||
|
"""Batch extraction queue for Free-tier users (Phase 2).
|
||||||
|
|
||||||
|
Pro/Power/Team users get realtime asyncio.create_task() extraction.
|
||||||
|
Free users get a queue row here; a daily cron (Phase 5) drains it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "extraction_queue"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
episode_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), nullable=True,
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRelation(Base):
|
||||||
|
"""Per-user entity/relation graph row (Mem0g-light, Phase 3).
|
||||||
|
|
||||||
|
subject_label/object_label are plaintext entity identifiers (not user content).
|
||||||
|
notes_encrypted is optional Fernet-encrypted per-user commentary.
|
||||||
|
confidence in [0.0, 1.0] — decays 5 % per 30 days since last_confirmed_at.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_relations"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
subject_label: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||||
|
subject_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
|
predicate: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
object_label: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||||
|
object_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
|
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.7)
|
||||||
|
source_episode_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False),
|
||||||
|
ForeignKey("memory_episodic.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
notes_encrypted: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
last_confirmed_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(Base):
|
||||||
|
"""Plugin marketplace catalog entry."""
|
||||||
|
|
||||||
|
__tablename__ = "plugins"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
version: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
author_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
category: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
|
||||||
|
status: Mapped[str] = mapped_column(String(50), nullable=False, default="pending")
|
||||||
|
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|||||||
311
app/schemas.py
311
app/schemas.py
@@ -30,6 +30,16 @@ class UserProfile(BaseModel):
|
|||||||
name: str | None = None
|
name: str | None = None
|
||||||
surname: str | None = None
|
surname: str | None = None
|
||||||
tier: BillingTier
|
tier: BillingTier
|
||||||
|
avatar_url: str | None = None
|
||||||
|
has_password: bool = True
|
||||||
|
onboarding_completed_at: int | None = None # epoch ms, null = not onboarded
|
||||||
|
memory: dict[str, str] = Field(default_factory=dict) # decrypted core memory k/v
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccountInfo(BaseModel):
|
||||||
|
provider: str
|
||||||
|
provider_email: str | None = None
|
||||||
|
created_at: int # epoch ms
|
||||||
|
|
||||||
|
|
||||||
# ── Chat ─────────────────────────────────────────────────────────────
|
# ── Chat ─────────────────────────────────────────────────────────────
|
||||||
@@ -41,123 +51,13 @@ class ChatContext(BaseModel):
|
|||||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class PlanAction(BaseModel):
|
|
||||||
type: Literal[
|
|
||||||
"create_record",
|
|
||||||
"update_record",
|
|
||||||
"delete_record",
|
|
||||||
"index_document",
|
|
||||||
"send_notification",
|
|
||||||
]
|
|
||||||
table: str | None = None
|
|
||||||
data: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
class ChatRequest(BaseModel):
|
||||||
message: str
|
message: str
|
||||||
context: ChatContext = Field(default_factory=ChatContext)
|
context: ChatContext = Field(default_factory=ChatContext)
|
||||||
execution_mode: Literal["direct", "plan"] = "direct"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatResponse(BaseModel):
|
class ChatResponse(BaseModel):
|
||||||
response: str
|
response: str
|
||||||
actions: list[PlanAction] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Execution Plans ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class PlanStep(BaseModel):
|
|
||||||
action: str
|
|
||||||
prompt_template: str | None = None
|
|
||||||
variables: dict[str, Any] | None = None
|
|
||||||
data_from_step: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlan(BaseModel):
|
|
||||||
agent: str
|
|
||||||
steps: list[PlanStep] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Backup ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class BackupMetadata(BaseModel):
|
|
||||||
version: int
|
|
||||||
timestamp: int
|
|
||||||
checksum: str
|
|
||||||
chunk_count: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
|
|
||||||
|
|
||||||
class StorageRecord(BaseModel):
|
|
||||||
id: str
|
|
||||||
user_id: str
|
|
||||||
table: str
|
|
||||||
blob: bytes
|
|
||||||
checksum: str
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
class StorageRecordCreate(BaseModel):
|
|
||||||
table: str
|
|
||||||
blob: bytes
|
|
||||||
checksum: str
|
|
||||||
|
|
||||||
|
|
||||||
class StorageRecordUpdate(BaseModel):
|
|
||||||
blob: bytes
|
|
||||||
checksum: str
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
|
|
||||||
|
|
||||||
class VectorItem(BaseModel):
|
|
||||||
id: str
|
|
||||||
blob: bytes # encrypted vector + metadata — backend never decrypts
|
|
||||||
checksum: str
|
|
||||||
|
|
||||||
|
|
||||||
class VectorUpsertRequest(BaseModel):
|
|
||||||
vectors: list[VectorItem]
|
|
||||||
|
|
||||||
|
|
||||||
class VectorSearchRequest(BaseModel):
|
|
||||||
query_blob: bytes # encrypted query — backend never decrypts
|
|
||||||
top_k: int = 10
|
|
||||||
|
|
||||||
|
|
||||||
class VectorSearchResult(BaseModel):
|
|
||||||
id: str
|
|
||||||
score: float
|
|
||||||
blob: bytes
|
|
||||||
|
|
||||||
|
|
||||||
class VectorSearchResponse(BaseModel):
|
|
||||||
results: list[VectorSearchResult]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Plugin Marketplace ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class PluginManifest(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
version: str
|
|
||||||
author: str
|
|
||||||
permissions: list[str]
|
|
||||||
category: str
|
|
||||||
price_cents: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class PluginListResponse(BaseModel):
|
|
||||||
plugins: list[PluginManifest]
|
|
||||||
total: int
|
|
||||||
page: int
|
|
||||||
|
|
||||||
|
|
||||||
class PluginInstallRequest(BaseModel):
|
|
||||||
plugin_id: str
|
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
||||||
@@ -170,21 +70,23 @@ class WsFrameType(str, Enum):
|
|||||||
tool_result = "tool_result"
|
tool_result = "tool_result"
|
||||||
final = "final"
|
final = "final"
|
||||||
ping = "ping"
|
ping = "ping"
|
||||||
agent_run = "agent_run"
|
|
||||||
agent_data = "agent_data"
|
|
||||||
agent_complete = "agent_complete"
|
|
||||||
device_hello = "device_hello"
|
device_hello = "device_hello"
|
||||||
# ── v3 frame types ─────────────────────────────────────────────────
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
home_request = "home_request"
|
home_request = "home_request"
|
||||||
floating_request = "floating_request"
|
floating_request = "floating_request"
|
||||||
stream_start = "stream_start"
|
stream_start = "stream_start"
|
||||||
stream_text = "stream_text"
|
stream_text = "stream_text"
|
||||||
stream_block = "stream_block"
|
|
||||||
stream_end = "stream_end"
|
stream_end = "stream_end"
|
||||||
floating_domain = "floating_domain"
|
floating_domain = "floating_domain"
|
||||||
data_request = "data_request"
|
data_request = "data_request"
|
||||||
data_response = "data_response"
|
data_response = "data_response"
|
||||||
mutation = "mutation"
|
mutation = "mutation"
|
||||||
|
# ── v4 journey frame types ────────────────────────────────────────
|
||||||
|
journey_start = "journey_start"
|
||||||
|
journey_message = "journey_message"
|
||||||
|
journey_reply = "journey_reply"
|
||||||
|
# ── v5 brief frame types ──────────────────────────────────────────
|
||||||
|
brief_request = "brief_request"
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -237,31 +139,6 @@ class WsDeviceHello(BaseModel):
|
|||||||
agent_ids: list[str] = Field(default_factory=list)
|
agent_ids: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class WsAgentRun(BaseModel):
|
|
||||||
"""Server → Client: trigger an agent run on the connected device."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_run] = WsFrameType.agent_run
|
|
||||||
run_id: str
|
|
||||||
agent_id: str
|
|
||||||
config: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class WsAgentData(BaseModel):
|
|
||||||
"""Client → Server: files read by the local agent."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_data] = WsFrameType.agent_data
|
|
||||||
run_id: str
|
|
||||||
files: list[dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class WsAgentComplete(BaseModel):
|
|
||||||
"""Client → Server: Electron signals it has finished reading files."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_complete] = WsFrameType.agent_complete
|
|
||||||
run_id: str
|
|
||||||
files_read: int
|
|
||||||
errors: list[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
@@ -288,6 +165,16 @@ class WsFloatingRequest(BaseModel):
|
|||||||
scope: WsFloatingScope
|
scope: WsFloatingScope
|
||||||
|
|
||||||
|
|
||||||
|
class WsBriefRequest(BaseModel):
|
||||||
|
"""Client → Server: Request a plain-text brief (home or project)."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.brief_request] = WsFrameType.brief_request
|
||||||
|
request_id: str | None = None
|
||||||
|
session_id: str | None = None
|
||||||
|
mode: Literal["home", "project"]
|
||||||
|
project_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class WsStreamStart(BaseModel):
|
class WsStreamStart(BaseModel):
|
||||||
"""Server → Client: signals start of a streaming response."""
|
"""Server → Client: signals start of a streaming response."""
|
||||||
|
|
||||||
@@ -303,21 +190,20 @@ class WsStreamText(BaseModel):
|
|||||||
chunk: str
|
chunk: str
|
||||||
|
|
||||||
|
|
||||||
class WsStreamBlock(BaseModel):
|
|
||||||
"""Server → Client: structured block (chart, table, entity, timeline)."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.stream_block] = WsFrameType.stream_block
|
|
||||||
request_id: str
|
|
||||||
block_type: Literal["chart", "entity_ref", "table", "timeline"]
|
|
||||||
data: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class WsStreamEnd(BaseModel):
|
class WsStreamEnd(BaseModel):
|
||||||
"""Server → Client: signals end of a streaming response."""
|
"""Server → Client: signals end of a streaming response."""
|
||||||
|
|
||||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
request_id: str
|
request_id: str
|
||||||
mutations: list[dict[str, Any]] = Field(default_factory=list)
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WsDomain(BaseModel):
|
||||||
|
"""Structured floating domain payload for UI routing decisions."""
|
||||||
|
|
||||||
|
type: Literal["task", "timeline", "project", "node"]
|
||||||
|
id: str | None = None
|
||||||
|
section: Literal["task", "timeline", "note"] | None = None
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingDomain(BaseModel):
|
class WsFloatingDomain(BaseModel):
|
||||||
@@ -325,7 +211,28 @@ class WsFloatingDomain(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||||
request_id: str
|
request_id: str
|
||||||
domain: Literal["tasks", "timelines", "notes", "projects"]
|
domain: WsDomain
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent Config V2 ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class ContentTypeConfig(BaseModel):
|
||||||
|
"""Per-type extraction config produced by the journey chatbot."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
label: str = ""
|
||||||
|
detection_hint: str = ""
|
||||||
|
preprocessing: str = "generic" # handler name: "email_html", "plain_text", ...
|
||||||
|
extraction_prompt: str
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfig(BaseModel):
|
||||||
|
"""Structured agent configuration (replaces freeform prompt_template)."""
|
||||||
|
|
||||||
|
content_types: list[ContentTypeConfig] = []
|
||||||
|
global_rules: list[str] = []
|
||||||
|
data_types: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
@@ -334,84 +241,29 @@ class AgentCatalogItem(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
config_schema: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local Agent Config ────────────────────────────────────────────────
|
class AgentCreationCheckRequest(BaseModel):
|
||||||
|
active_agents: int = Field(ge=0, default=0)
|
||||||
class LocalAgentConfigCreate(BaseModel):
|
|
||||||
name: str
|
|
||||||
device_id: str
|
|
||||||
directory_paths: list[str]
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
file_extensions: list[str]
|
|
||||||
schedule_cron: str
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfigUpdate(BaseModel):
|
class AgentCreationCheckResponse(BaseModel):
|
||||||
name: str | None = None
|
allowed: bool
|
||||||
device_id: str | None = None
|
tier: BillingTier
|
||||||
directory_paths: list[str] | None = None
|
active_agents: int
|
||||||
data_types: list[str] | None = None
|
limit: int
|
||||||
prompt_template: str | None = None
|
|
||||||
file_extensions: list[str] | None = None
|
|
||||||
schedule_cron: str | None = None
|
|
||||||
enabled: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfigResponse(BaseModel):
|
class AgentTriggerRequest(BaseModel):
|
||||||
id: str
|
directory: str = Field(min_length=1)
|
||||||
name: str
|
device_id: str = Field(default="")
|
||||||
device_id: str
|
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
||||||
directory_paths: list[str]
|
what_to_extract: list[str] = Field(min_length=1)
|
||||||
data_types: list[str]
|
batch_interval: str = Field(min_length=1)
|
||||||
prompt_template: str
|
custom_agent_prompt: str | None = None
|
||||||
file_extensions: list[str]
|
agent_config: dict | None = None
|
||||||
schedule_cron: str
|
active_agents: int = Field(ge=0, default=0)
|
||||||
enabled: bool
|
last_run_at: int | None = None # epoch ms from FE — enables incremental scanning
|
||||||
last_run_at: int | None
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud Agent Config ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class CloudAgentConfigCreate(BaseModel):
|
|
||||||
provider: Literal["gmail", "teams", "outlook"]
|
|
||||||
name: str
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
oauth_token_encrypted: str
|
|
||||||
schedule_cron: str
|
|
||||||
filter_config: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class CloudAgentConfigUpdate(BaseModel):
|
|
||||||
provider: Literal["gmail", "teams", "outlook"] | None = None
|
|
||||||
name: str | None = None
|
|
||||||
data_types: list[str] | None = None
|
|
||||||
prompt_template: str | None = None
|
|
||||||
oauth_token_encrypted: str | None = None
|
|
||||||
schedule_cron: str | None = None
|
|
||||||
filter_config: dict[str, Any] | None = None
|
|
||||||
enabled: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class CloudAgentConfigResponse(BaseModel):
|
|
||||||
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
provider: Literal["gmail", "teams", "outlook"]
|
|
||||||
name: str
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
schedule_cron: str
|
|
||||||
filter_config: dict[str, Any] | None
|
|
||||||
enabled: bool
|
|
||||||
last_run_at: int | None
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
@@ -430,18 +282,3 @@ class AgentRunLogResponse(BaseModel):
|
|||||||
|
|
||||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||||
|
|
||||||
class JourneyStartRequest(BaseModel):
|
|
||||||
agent_type: Literal["local", "cloud"]
|
|
||||||
agent_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class JourneyMessageRequest(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
class JourneyResponse(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
message: str
|
|
||||||
done: bool
|
|
||||||
prompt_template: str | None = None
|
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
"""Cloud storage layer — E2E encrypted blobs and vectors."""
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
"""S3-backed store for E2E-encrypted blobs.
|
|
||||||
|
|
||||||
Keys are structured as ``{user_id}/{table}/{record_id}``.
|
|
||||||
The backend never inspects blob content — it stores and retrieves opaque bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import boto3
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
class BlobStore:
|
|
||||||
"""Thin wrapper around boto3 S3.
|
|
||||||
|
|
||||||
All blobs must be E2E encrypted by the client before upload.
|
|
||||||
The backend adds SSE-S3 as an extra layer of at-rest encryption
|
|
||||||
but cannot decrypt the inner client-side payload.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _client(self) -> Any:
|
|
||||||
kwargs: dict[str, Any] = {
|
|
||||||
"region_name": settings.S3_REGION,
|
|
||||||
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
|
|
||||||
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
|
|
||||||
}
|
|
||||||
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
|
|
||||||
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
|
|
||||||
return boto3.client("s3", **kwargs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _key(user_id: str, table: str, record_id: str) -> str:
|
|
||||||
return f"{user_id}/{table}/{record_id}"
|
|
||||||
|
|
||||||
async def upload(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
table: str,
|
|
||||||
record_id: str,
|
|
||||||
blob: bytes,
|
|
||||||
checksum: str,
|
|
||||||
) -> str:
|
|
||||||
"""Store *blob* in S3 and return the S3 key.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Owner of the blob (used as key prefix).
|
|
||||||
table: Logical table name (e.g. ``"tasks"``).
|
|
||||||
record_id: Record UUID.
|
|
||||||
blob: Raw bytes (pre-encrypted by client).
|
|
||||||
checksum: SHA-256 hex digest supplied by the client; stored as
|
|
||||||
object metadata for download-time verification.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The S3 key under which the blob was stored.
|
|
||||||
"""
|
|
||||||
key = self._key(user_id, table, record_id)
|
|
||||||
self._client().put_object(
|
|
||||||
Bucket=settings.S3_BUCKET,
|
|
||||||
Key=key,
|
|
||||||
Body=blob,
|
|
||||||
ServerSideEncryption="AES256", # SSE-S3 at rest
|
|
||||||
Metadata={"checksum": checksum},
|
|
||||||
)
|
|
||||||
return key
|
|
||||||
|
|
||||||
async def download(self, user_id: str, s3_key: str) -> bytes:
|
|
||||||
"""Retrieve the blob stored at *s3_key*.
|
|
||||||
|
|
||||||
*user_id* is retained in the signature so higher-level code can
|
|
||||||
enforce ownership without re-parsing the key.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
|
|
||||||
object does not exist.
|
|
||||||
"""
|
|
||||||
response = self._client().get_object(
|
|
||||||
Bucket=settings.S3_BUCKET,
|
|
||||||
Key=s3_key,
|
|
||||||
)
|
|
||||||
return response["Body"].read()
|
|
||||||
|
|
||||||
async def delete(self, user_id: str, s3_key: str) -> None:
|
|
||||||
"""Delete the object at *s3_key*.
|
|
||||||
|
|
||||||
S3 ``delete_object`` is idempotent — it succeeds even if the key does
|
|
||||||
not exist.
|
|
||||||
"""
|
|
||||||
self._client().delete_object(
|
|
||||||
Bucket=settings.S3_BUCKET,
|
|
||||||
Key=s3_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def list_keys(self, user_id: str, table: str) -> list[str]:
|
|
||||||
"""Return all S3 keys for a given user + table combination.
|
|
||||||
|
|
||||||
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
|
|
||||||
"""
|
|
||||||
prefix = f"{user_id}/{table}/"
|
|
||||||
response = self._client().list_objects_v2(
|
|
||||||
Bucket=settings.S3_BUCKET,
|
|
||||||
Prefix=prefix,
|
|
||||||
)
|
|
||||||
return [obj["Key"] for obj in response.get("Contents", [])]
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
"""Integrity verification only — the backend NEVER decrypts user data."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
|
|
||||||
def verify_checksum(blob: bytes, checksum: str) -> bool:
|
|
||||||
"""Return ``True`` if SHA-256(blob) matches *checksum*.
|
|
||||||
|
|
||||||
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
|
|
||||||
timing-based side-channel attacks.
|
|
||||||
"""
|
|
||||||
computed = hashlib.sha256(blob).hexdigest()
|
|
||||||
return hmac.compare_digest(computed, checksum)
|
|
||||||
|
|
||||||
|
|
||||||
def reject_if_tampered(blob: bytes, checksum: str) -> None:
|
|
||||||
"""Raise ``HTTP 400`` if the blob does not match its checksum.
|
|
||||||
|
|
||||||
Call this before storing or forwarding any client-provided blob.
|
|
||||||
The backend never holds decryption keys — this check only verifies
|
|
||||||
that the opaque bytes arrived intact.
|
|
||||||
"""
|
|
||||||
if not verify_checksum(blob, checksum):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Checksum mismatch: blob integrity check failed",
|
|
||||||
)
|
|
||||||
@@ -1,205 +0,0 @@
|
|||||||
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
|
|
||||||
|
|
||||||
Vectors are pre-encrypted blobs from the client. The backend stores them
|
|
||||||
alongside a deterministic 32-dim float representation derived from the blob's
|
|
||||||
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
|
|
||||||
is a known trade-off documented in the backend plan.
|
|
||||||
|
|
||||||
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
|
|
||||||
``user_id`` payload field on a shared collection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pinecone import Pinecone
|
|
||||||
from qdrant_client import QdrantClient
|
|
||||||
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.schemas import VectorItem, VectorSearchResult
|
|
||||||
|
|
||||||
_QDRANT_COLLECTION = "adiuva_vectors"
|
|
||||||
|
|
||||||
|
|
||||||
def _blob_to_vector(blob: bytes) -> list[float]:
|
|
||||||
"""Derive a 32-dim float vector from *blob* for storage purposes only.
|
|
||||||
|
|
||||||
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
|
|
||||||
normalises each byte to the range [-1.0, 1.0]. This vector carries no
|
|
||||||
semantic meaning on encrypted data.
|
|
||||||
"""
|
|
||||||
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
|
|
||||||
|
|
||||||
|
|
||||||
class VectorStore:
|
|
||||||
"""Thin wrapper around Pinecone or Qdrant.
|
|
||||||
|
|
||||||
The backend to use is selected at runtime:
|
|
||||||
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
|
|
||||||
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _use_pinecone(self) -> bool:
|
|
||||||
return bool(settings.PINECONE_API_KEY)
|
|
||||||
|
|
||||||
# ── Pinecone helpers ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _pinecone_index(self) -> Any:
|
|
||||||
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
|
|
||||||
return pc.Index(settings.PINECONE_INDEX)
|
|
||||||
|
|
||||||
# ── Qdrant helpers ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _qdrant_client(self) -> Any:
|
|
||||||
return QdrantClient(
|
|
||||||
url=settings.QDRANT_URL,
|
|
||||||
api_key=settings.QDRANT_API_KEY or None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
|
||||||
"""Store encrypted vectors in the backend.
|
|
||||||
|
|
||||||
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
|
|
||||||
so it can be returned verbatim during search.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Used as Pinecone namespace or Qdrant payload field.
|
|
||||||
vectors: List of encrypted vector items from the client.
|
|
||||||
"""
|
|
||||||
if self._use_pinecone():
|
|
||||||
await self._pinecone_upsert(user_id, vectors)
|
|
||||||
else:
|
|
||||||
await self._qdrant_upsert(user_id, vectors)
|
|
||||||
|
|
||||||
async def search(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
query_blob: bytes,
|
|
||||||
top_k: int,
|
|
||||||
) -> list[VectorSearchResult]:
|
|
||||||
"""Query the vector store and return encrypted result blobs.
|
|
||||||
|
|
||||||
The query vector is derived from *query_blob* using the same
|
|
||||||
deterministic mapping as upsert.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Scopes the search to this user's namespace.
|
|
||||||
query_blob: Encrypted query from the client.
|
|
||||||
top_k: Maximum number of results to return.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
|
|
||||||
"""
|
|
||||||
if self._use_pinecone():
|
|
||||||
return await self._pinecone_search(user_id, query_blob, top_k)
|
|
||||||
return await self._qdrant_search(user_id, query_blob, top_k)
|
|
||||||
|
|
||||||
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
|
|
||||||
"""Remove vectors by ID, scoped to *user_id*.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Namespace / payload filter to prevent cross-user deletion.
|
|
||||||
vector_ids: List of vector IDs to remove.
|
|
||||||
"""
|
|
||||||
if self._use_pinecone():
|
|
||||||
await self._pinecone_delete(user_id, vector_ids)
|
|
||||||
else:
|
|
||||||
await self._qdrant_delete(user_id, vector_ids)
|
|
||||||
|
|
||||||
# ── Pinecone implementation ───────────────────────────────────────
|
|
||||||
|
|
||||||
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
|
||||||
index = self._pinecone_index()
|
|
||||||
records = [
|
|
||||||
{
|
|
||||||
"id": v.id,
|
|
||||||
"values": _blob_to_vector(v.blob),
|
|
||||||
"metadata": {
|
|
||||||
"blob": base64.b64encode(v.blob).decode(),
|
|
||||||
"checksum": v.checksum,
|
|
||||||
"user_id": user_id,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for v in vectors
|
|
||||||
]
|
|
||||||
index.upsert(vectors=records, namespace=user_id)
|
|
||||||
|
|
||||||
async def _pinecone_search(
|
|
||||||
self, user_id: str, query_blob: bytes, top_k: int
|
|
||||||
) -> list[VectorSearchResult]:
|
|
||||||
index = self._pinecone_index()
|
|
||||||
query_vector = _blob_to_vector(query_blob)
|
|
||||||
response = index.query(
|
|
||||||
vector=query_vector,
|
|
||||||
top_k=top_k,
|
|
||||||
namespace=user_id,
|
|
||||||
include_metadata=True,
|
|
||||||
)
|
|
||||||
results: list[VectorSearchResult] = []
|
|
||||||
for match in response.get("matches", []):
|
|
||||||
blob_bytes = base64.b64decode(match["metadata"]["blob"])
|
|
||||||
results.append(
|
|
||||||
VectorSearchResult(
|
|
||||||
id=match["id"],
|
|
||||||
score=match["score"],
|
|
||||||
blob=blob_bytes,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
|
||||||
index = self._pinecone_index()
|
|
||||||
index.delete(ids=vector_ids, namespace=user_id)
|
|
||||||
|
|
||||||
# ── Qdrant implementation ─────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
|
||||||
client = self._qdrant_client()
|
|
||||||
points = [
|
|
||||||
PointStruct(
|
|
||||||
id=v.id,
|
|
||||||
vector=_blob_to_vector(v.blob),
|
|
||||||
payload={
|
|
||||||
"blob": base64.b64encode(v.blob).decode(),
|
|
||||||
"checksum": v.checksum,
|
|
||||||
"user_id": user_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
for v in vectors
|
|
||||||
]
|
|
||||||
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
|
|
||||||
|
|
||||||
async def _qdrant_search(
|
|
||||||
self, user_id: str, query_blob: bytes, top_k: int
|
|
||||||
) -> list[VectorSearchResult]:
|
|
||||||
client = self._qdrant_client()
|
|
||||||
query_vector = _blob_to_vector(query_blob)
|
|
||||||
hits = client.search(
|
|
||||||
collection_name=_QDRANT_COLLECTION,
|
|
||||||
query_vector=query_vector,
|
|
||||||
query_filter=Filter(
|
|
||||||
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
|
|
||||||
),
|
|
||||||
limit=top_k,
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
VectorSearchResult(
|
|
||||||
id=str(hit.id),
|
|
||||||
score=hit.score,
|
|
||||||
blob=base64.b64decode(hit.payload["blob"]),
|
|
||||||
)
|
|
||||||
for hit in hits
|
|
||||||
]
|
|
||||||
|
|
||||||
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
|
||||||
client = self._qdrant_client()
|
|
||||||
client.delete(
|
|
||||||
collection_name=_QDRANT_COLLECTION,
|
|
||||||
points_selector=PointIdsList(points=vector_ids),
|
|
||||||
)
|
|
||||||
@@ -7,7 +7,7 @@ services:
|
|||||||
- path: .env
|
- path: .env
|
||||||
required: false
|
required: false
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuvai
|
||||||
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
||||||
volumes:
|
volumes:
|
||||||
- copilot_tokens:/root/.config/litellm/github_copilot
|
- copilot_tokens:/root/.config/litellm/github_copilot
|
||||||
@@ -21,7 +21,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
POSTGRES_USER: postgres
|
POSTGRES_USER: postgres
|
||||||
POSTGRES_PASSWORD: postgres
|
POSTGRES_PASSWORD: postgres
|
||||||
POSTGRES_DB: adiuva
|
POSTGRES_DB: adiuvai
|
||||||
volumes:
|
volumes:
|
||||||
- postgres_data:/var/lib/postgresql/data
|
- postgres_data:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
@@ -36,37 +36,6 @@ services:
|
|||||||
# image: redis:7-alpine
|
# image: redis:7-alpine
|
||||||
# restart: unless-stopped
|
# restart: unless-stopped
|
||||||
|
|
||||||
# ── Local S3-compatible storage (MinIO) ──
|
|
||||||
minio:
|
|
||||||
image: minio/minio:latest
|
|
||||||
command: server /data --console-address ":9001"
|
|
||||||
ports:
|
|
||||||
- "9000:9000"
|
|
||||||
- "9001:9001"
|
|
||||||
environment:
|
|
||||||
MINIO_ROOT_USER: minioadmin
|
|
||||||
MINIO_ROOT_PASSWORD: minioadmin
|
|
||||||
volumes:
|
|
||||||
- minio_data:/data
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "mc", "ready", "local"]
|
|
||||||
interval: 5s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 5
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
# ── Local vector store (Qdrant) ──
|
|
||||||
qdrant:
|
|
||||||
image: qdrant/qdrant:latest
|
|
||||||
ports:
|
|
||||||
- "6333:6333"
|
|
||||||
- "6334:6334"
|
|
||||||
volumes:
|
|
||||||
- qdrant_data:/qdrant/storage
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres_data:
|
postgres_data:
|
||||||
minio_data:
|
|
||||||
qdrant_data:
|
|
||||||
copilot_tokens:
|
copilot_tokens:
|
||||||
|
|||||||
@@ -32,4 +32,10 @@ google-auth-oauthlib>=1.2.0
|
|||||||
google-auth-httplib2>=0.2.0
|
google-auth-httplib2>=0.2.0
|
||||||
msal>=1.28.0
|
msal>=1.28.0
|
||||||
cryptography>=42.0.0
|
cryptography>=42.0.0
|
||||||
|
pgvector>=0.2.5
|
||||||
|
langfuse>=2.0.0
|
||||||
|
beautifulsoup4>=4.12.0
|
||||||
|
lxml>=5.0.0
|
||||||
|
PyYAML>=6.0.0
|
||||||
|
apscheduler>=3.10.0
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
|||||||
1
results.xml
Normal file
1
results.xml
Normal file
File diff suppressed because one or more lines are too long
@@ -6,26 +6,21 @@ a per-test session, and a FastAPI ``TestClient`` wired to use it.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator, Generator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import boto3
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from moto import mock_aws
|
|
||||||
from sqlalchemy import StaticPool, event
|
from sqlalchemy import StaticPool, event
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.db import Base, get_session
|
from app.db import Base, get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
from app.models import Plugin, Subscription, User
|
from app.models import Subscription, User
|
||||||
|
|
||||||
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
||||||
|
|
||||||
@@ -109,79 +104,6 @@ def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # n
|
|||||||
app.dependency_overrides.pop(get_session, None)
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
# ── Seed data helpers ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_SEED_PLUGINS = [
|
|
||||||
Plugin(
|
|
||||||
id="plugin-github-sync",
|
|
||||||
name="GitHub Sync",
|
|
||||||
description="Sync tasks with GitHub Issues and pull requests.",
|
|
||||||
version="1.0.0",
|
|
||||||
author_name="Adiuva",
|
|
||||||
category="productivity",
|
|
||||||
price_cents=0,
|
|
||||||
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
|
||||||
status="approved",
|
|
||||||
s3_package_key="plugins/plugin-github-sync/1.0.0/package.zip",
|
|
||||||
install_count=0,
|
|
||||||
avg_rating=0.0,
|
|
||||||
),
|
|
||||||
Plugin(
|
|
||||||
id="plugin-slack-notify",
|
|
||||||
name="Slack Notifier",
|
|
||||||
description="Post task and timeline updates to Slack channels.",
|
|
||||||
version="1.2.0",
|
|
||||||
author_name="Adiuva",
|
|
||||||
category="communication",
|
|
||||||
price_cents=499,
|
|
||||||
permissions=json.dumps(["read:tasks", "read:timelines"]),
|
|
||||||
status="approved",
|
|
||||||
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
|
|
||||||
install_count=0,
|
|
||||||
avg_rating=0.0,
|
|
||||||
),
|
|
||||||
Plugin(
|
|
||||||
id="plugin-time-tracker",
|
|
||||||
name="Time Tracker",
|
|
||||||
description="Track time spent on tasks with automatic reporting.",
|
|
||||||
version="0.9.1",
|
|
||||||
author_name="Third Party",
|
|
||||||
category="productivity",
|
|
||||||
price_cents=999,
|
|
||||||
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
|
||||||
status="approved",
|
|
||||||
s3_package_key="plugins/plugin-time-tracker/0.9.1/package.zip",
|
|
||||||
install_count=0,
|
|
||||||
avg_rating=0.0,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
|
||||||
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
|
|
||||||
"""Insert the 3 default approved plugins and return them."""
|
|
||||||
plugins = []
|
|
||||||
for template in _SEED_PLUGINS:
|
|
||||||
p = Plugin(
|
|
||||||
id=template.id,
|
|
||||||
name=template.name,
|
|
||||||
description=template.description,
|
|
||||||
version=template.version,
|
|
||||||
author_name=template.author_name,
|
|
||||||
category=template.category,
|
|
||||||
price_cents=template.price_cents,
|
|
||||||
permissions=template.permissions,
|
|
||||||
status=template.status,
|
|
||||||
s3_package_key=template.s3_package_key,
|
|
||||||
install_count=template.install_count,
|
|
||||||
avg_rating=template.avg_rating,
|
|
||||||
)
|
|
||||||
db_session.add(p)
|
|
||||||
plugins.append(p)
|
|
||||||
await db_session.commit()
|
|
||||||
return plugins
|
|
||||||
|
|
||||||
|
|
||||||
# ── JWT helpers ──────────────────────────────────────────────────────
|
# ── JWT helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -212,24 +134,21 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
|
|||||||
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
||||||
|
|
||||||
|
|
||||||
# ── S3 mock fixture ──────────────────────────────────────────────────
|
# ── CLI options ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
S3_TEST_BUCKET = "test-bucket"
|
def pytest_addoption(parser):
|
||||||
S3_TEST_REGION = "us-east-1"
|
parser.addoption(
|
||||||
|
"--preprocess-dir",
|
||||||
|
default=None,
|
||||||
@pytest.fixture
|
help="Override fixture folder for preprocessor tests (must contain cases.yaml + data/)",
|
||||||
def s3_bucket():
|
)
|
||||||
"""Create a mocked S3 bucket via moto and patch BlobStore settings."""
|
parser.addoption(
|
||||||
with mock_aws():
|
"--runner-dir",
|
||||||
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
|
default=None,
|
||||||
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
|
help="Override fixture folder for agent_runner_v2 eval tests (must contain cases.yaml + data/)",
|
||||||
os.environ.setdefault("AWS_DEFAULT_REGION", S3_TEST_REGION)
|
)
|
||||||
client = boto3.client("s3", region_name=S3_TEST_REGION)
|
parser.addoption(
|
||||||
client.create_bucket(Bucket=S3_TEST_BUCKET)
|
"--journey-dir",
|
||||||
with patch("app.storage.blob_store.settings") as mock_settings:
|
default=None,
|
||||||
mock_settings.S3_BUCKET = S3_TEST_BUCKET
|
help="Override fixture folder for journey_v2 eval tests (must contain cases.yaml + data/)",
|
||||||
mock_settings.S3_REGION = S3_TEST_REGION
|
)
|
||||||
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
|
||||||
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
|
||||||
yield S3_TEST_BUCKET
|
|
||||||
|
|||||||
86
tests/fixtures/agent_runner_v2/cases.yaml
vendored
Normal file
86
tests/fixtures/agent_runner_v2/cases.yaml
vendored
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Agent Runner V2 — eval test cases (Step 2, requires real LLM)
|
||||||
|
#
|
||||||
|
# Each case drives one parametrized `test_eval_runner` invocation.
|
||||||
|
#
|
||||||
|
# Keys
|
||||||
|
# ----
|
||||||
|
# id: str unique identifier shown in pytest output
|
||||||
|
# description: str human-readable label
|
||||||
|
# file: str filename inside data/
|
||||||
|
# file_path: str path reported to the executor (affects project-matching via filename)
|
||||||
|
# projects: [alpha|beta] symbolic project names resolved by the test helper
|
||||||
|
#
|
||||||
|
# Optional pre-existing records (dedup tests)
|
||||||
|
# existing_tasks: list of {id, title, status, priority}
|
||||||
|
# existing_notes: list of {id, title, content}
|
||||||
|
# existing_timelines: list of {id, title, date}
|
||||||
|
#
|
||||||
|
# Assertions (one or more)
|
||||||
|
# expect_insert: <table> at least 1 insert row in this table (tasks|notes|timelines)
|
||||||
|
# expect_no_insert: true zero inserts in any table
|
||||||
|
# expect_project_id: <id> any insert must carry this projectId
|
||||||
|
# expect_dedup: true task inserts == 0 OR task updates >= 1 (dedup check)
|
||||||
|
#
|
||||||
|
# Langfuse
|
||||||
|
# score_name: str observation score name
|
||||||
|
|
||||||
|
- id: "2.1"
|
||||||
|
description: "Action email → create_task"
|
||||||
|
file: email_action.html
|
||||||
|
file_path: /emails/ProjectAlpha_action.html
|
||||||
|
projects: [alpha, beta]
|
||||||
|
expect_insert: tasks
|
||||||
|
score_name: runner.email_to_task
|
||||||
|
|
||||||
|
- id: "2.2"
|
||||||
|
description: "Informational email → create_note"
|
||||||
|
file: email_info.html
|
||||||
|
file_path: /emails/ProjectAlpha_info.html
|
||||||
|
projects: [alpha, beta]
|
||||||
|
expect_insert: notes
|
||||||
|
score_name: runner.email_to_note
|
||||||
|
|
||||||
|
- id: "2.3"
|
||||||
|
description: "Email with meeting date → create_timeline"
|
||||||
|
file: email_date.html
|
||||||
|
file_path: /emails/ProjectAlpha_kickoff.html
|
||||||
|
projects: [alpha, beta]
|
||||||
|
expect_insert: timelines
|
||||||
|
score_name: runner.email_to_timeline
|
||||||
|
|
||||||
|
- id: "2.4"
|
||||||
|
description: "Filename contains project name → correct project assigned"
|
||||||
|
file: email_action.html
|
||||||
|
file_path: /emails/ProjectAlpha_report.html
|
||||||
|
projects: [alpha, beta]
|
||||||
|
expect_project_id: proj-alpha
|
||||||
|
score_name: runner.project_filename
|
||||||
|
|
||||||
|
- id: "2.5"
|
||||||
|
description: "Email body mentions project → correct project assigned"
|
||||||
|
file: email_action.html
|
||||||
|
file_path: /emails/email_001.html
|
||||||
|
projects: [alpha, beta]
|
||||||
|
expect_project_id: proj-alpha
|
||||||
|
score_name: runner.project_content
|
||||||
|
|
||||||
|
- id: "2.6"
|
||||||
|
description: "Newsletter + global rule no-project → no creates"
|
||||||
|
file: email_no_project.html
|
||||||
|
file_path: /emails/newsletter.html
|
||||||
|
projects: [alpha, beta]
|
||||||
|
expect_no_insert: true
|
||||||
|
score_name: runner.no_project
|
||||||
|
|
||||||
|
- id: "2.7"
|
||||||
|
description: "Existing task with same title → dedup (update not create)"
|
||||||
|
file: email_action.html
|
||||||
|
file_path: /emails/ProjectAlpha_followup.html
|
||||||
|
projects: [alpha]
|
||||||
|
existing_tasks:
|
||||||
|
- id: task-existing
|
||||||
|
title: Fix the login bug
|
||||||
|
status: todo
|
||||||
|
priority: medium
|
||||||
|
expect_dedup: true
|
||||||
|
score_name: runner.dedup
|
||||||
7
tests/fixtures/agent_runner_v2/data/email_action.html
vendored
Normal file
7
tests/fixtures/agent_runner_v2/data/email_action.html
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<html><head></head><body>
|
||||||
|
<p><b>From:</b> boss@company.com</p>
|
||||||
|
<p><b>To:</b> dev@company.com</p>
|
||||||
|
<p><b>Subject:</b> Fix the login bug</p>
|
||||||
|
<p><b>Date:</b> 2026-04-07</p>
|
||||||
|
<p>Hi,<br>Please fix the login bug in Project Alpha by Friday. High priority!</p>
|
||||||
|
</body></html>
|
||||||
5
tests/fixtures/agent_runner_v2/data/email_date.html
vendored
Normal file
5
tests/fixtures/agent_runner_v2/data/email_date.html
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
<html><head></head><body>
|
||||||
|
<p><b>From:</b> pm@company.com</p>
|
||||||
|
<p><b>Subject:</b> Project Alpha kick-off meeting</p>
|
||||||
|
<p>The kick-off meeting for Project Alpha is scheduled for 2026-04-15 at 10:00.</p>
|
||||||
|
</body></html>
|
||||||
7
tests/fixtures/agent_runner_v2/data/email_info.html
vendored
Normal file
7
tests/fixtures/agent_runner_v2/data/email_info.html
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<html><head></head><body>
|
||||||
|
<p><b>From:</b> pm@company.com</p>
|
||||||
|
<p><b>To:</b> team@company.com</p>
|
||||||
|
<p><b>Subject:</b> FYI: New policy for Project Alpha</p>
|
||||||
|
<p>Just a heads-up that starting next week all code reviews must be done
|
||||||
|
within 24 hours for Project Alpha. No action needed from you now.</p>
|
||||||
|
</body></html>
|
||||||
5
tests/fixtures/agent_runner_v2/data/email_no_project.html
vendored
Normal file
5
tests/fixtures/agent_runner_v2/data/email_no_project.html
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
<html><head></head><body>
|
||||||
|
<p><b>From:</b> newsletter@ads.com</p>
|
||||||
|
<p><b>Subject:</b> Weekly newsletter</p>
|
||||||
|
<p>Check out our latest deals on electronics!</p>
|
||||||
|
</body></html>
|
||||||
19
tests/fixtures/journey_v2/cases.yaml
vendored
Normal file
19
tests/fixtures/journey_v2/cases.yaml
vendored
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# Journey V2 eval test cases — Step 4
|
||||||
|
#
|
||||||
|
# Only case 4.1 is kept as an automated eval. Cases 4.2–4.5 (multi-turn
|
||||||
|
# conversations that expect the LLM to produce a complete AgentConfig)
|
||||||
|
# are non-deterministic and tested manually — results tracked in Langfuse.
|
||||||
|
#
|
||||||
|
# Assertion keys:
|
||||||
|
# expect_question: true → first reply must contain "?"
|
||||||
|
|
||||||
|
- id: "4.1"
|
||||||
|
description: "Journey start explores directory, first reply contains a question"
|
||||||
|
directory: "/test/emails"
|
||||||
|
data_types: ["tasks", "notes", "timelines"]
|
||||||
|
directory_files:
|
||||||
|
- path: "/test/emails/outlook_export_2024.html"
|
||||||
|
content_file: "email_action.html"
|
||||||
|
user_messages: []
|
||||||
|
score_name: "journey.start"
|
||||||
|
expect_question: true
|
||||||
23
tests/fixtures/journey_v2/data/email_action.html
vendored
Normal file
23
tests/fixtures/journey_v2/data/email_action.html
vendored
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>Email: Fix the login bug</title>
|
||||||
|
<style>body { font-family: Arial; } .header { color: #666; }</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="header">
|
||||||
|
<p><strong>From:</strong> boss@company.com</p>
|
||||||
|
<p><strong>To:</strong> dev@company.com</p>
|
||||||
|
<p><strong>Subject:</strong> Fix the login bug</p>
|
||||||
|
<p><strong>Date:</strong> Mon, 7 Apr 2026 09:15:00 +0000</p>
|
||||||
|
</div>
|
||||||
|
<div class="body">
|
||||||
|
<p>Hi,</p>
|
||||||
|
<p>Please fix the login bug in Project Alpha as soon as possible.
|
||||||
|
Users are reporting that they can't log in with their Google accounts.
|
||||||
|
This is blocking the whole team. Please resolve it by Friday.</p>
|
||||||
|
<p>Thanks,<br>Boss</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
23
tests/fixtures/journey_v2/data/email_info.html
vendored
Normal file
23
tests/fixtures/journey_v2/data/email_info.html
vendored
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>Email: New policy update</title>
|
||||||
|
<style>body { font-family: Arial; }</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="header">
|
||||||
|
<p><strong>From:</strong> hr@company.com</p>
|
||||||
|
<p><strong>To:</strong> all@company.com</p>
|
||||||
|
<p><strong>Subject:</strong> FYI: New remote work policy effective May 1</p>
|
||||||
|
<p><strong>Date:</strong> Tue, 8 Apr 2026 10:00:00 +0000</p>
|
||||||
|
</div>
|
||||||
|
<div class="body">
|
||||||
|
<p>Hi everyone,</p>
|
||||||
|
<p>Just a heads-up that starting May 1, 2026 the company will be moving to
|
||||||
|
a hybrid work model. You will be expected to come into the office at least
|
||||||
|
two days per week. More details will follow in the employee handbook.</p>
|
||||||
|
<p>Best,<br>HR Team</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
68
tests/fixtures/preprocessors/cases.yaml
vendored
Normal file
68
tests/fixtures/preprocessors/cases.yaml
vendored
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# Preprocessor test cases
|
||||||
|
#
|
||||||
|
# detect: <expected_type> → chiama detect_content_type(filename, content)
|
||||||
|
# process: <content_type> → chiama preprocess(content_type, content)
|
||||||
|
#
|
||||||
|
# Sorgente: file: <nome in data/> oppure generate: binary_noise
|
||||||
|
#
|
||||||
|
# Assertions piatte (solo per process):
|
||||||
|
# no_html: true clean_text senza tag HTML
|
||||||
|
# min_chars: N len(clean_text) >= N
|
||||||
|
# ratio_lt: F len(clean) / len(raw) < F
|
||||||
|
# has_meta: [k, ...] chiavi presenti in metadata
|
||||||
|
# contains: str | [str] substring(s) presenti in clean_text
|
||||||
|
# excludes: str | [str] substring(s) assenti da clean_text
|
||||||
|
# content_type: str result.content_type == questo valore
|
||||||
|
|
||||||
|
- id: "1.1"
|
||||||
|
file: email_action.html
|
||||||
|
detect: email_html
|
||||||
|
|
||||||
|
- id: "1.2"
|
||||||
|
file: generic_page.html
|
||||||
|
detect: generic_html
|
||||||
|
|
||||||
|
- id: "1.3"
|
||||||
|
file: notes.txt
|
||||||
|
detect: plain_text
|
||||||
|
|
||||||
|
- id: "1.4"
|
||||||
|
file: archive.xyz
|
||||||
|
generate: binary_noise
|
||||||
|
detect: unknown
|
||||||
|
|
||||||
|
- id: "1.5"
|
||||||
|
file: email_action.html
|
||||||
|
process: email_html
|
||||||
|
no_html: true
|
||||||
|
min_chars: 50
|
||||||
|
ratio_lt: 0.8
|
||||||
|
|
||||||
|
- id: "1.6"
|
||||||
|
file: email_action.html
|
||||||
|
process: email_html
|
||||||
|
has_meta: [subject, from]
|
||||||
|
|
||||||
|
- id: "1.7"
|
||||||
|
file: email_thread.html
|
||||||
|
process: email_html
|
||||||
|
contains: "Sure, I'll handle the deploy"
|
||||||
|
excludes: "Let's plan the deploy"
|
||||||
|
|
||||||
|
- id: "1.8"
|
||||||
|
file: email_single.html
|
||||||
|
process: email_html
|
||||||
|
contains: "deploy is done"
|
||||||
|
|
||||||
|
- id: "1.9"
|
||||||
|
file: email_heavy.html
|
||||||
|
process: email_html
|
||||||
|
no_html: true
|
||||||
|
min_chars: 30
|
||||||
|
excludes: [border-collapse, font-size]
|
||||||
|
|
||||||
|
- id: "1.10"
|
||||||
|
file: fallback.txt
|
||||||
|
process: unknown
|
||||||
|
min_chars: 1
|
||||||
|
content_type: unknown
|
||||||
25
tests/fixtures/preprocessors/data/email_action.html
vendored
Normal file
25
tests/fixtures/preprocessors/data/email_action.html
vendored
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Fix the login bug</title>
|
||||||
|
<style>
|
||||||
|
body { font-family: Arial, sans-serif; color: #333; margin: 0; padding: 20px; }
|
||||||
|
.header { background: #f5f5f5; padding: 10px; border-bottom: 1px solid #ddd; }
|
||||||
|
.body { padding: 20px; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="header">
|
||||||
|
<p><strong>From:</strong> boss@company.com</p>
|
||||||
|
<p><strong>To:</strong> dev@company.com</p>
|
||||||
|
<p><strong>Subject:</strong> Fix the login bug</p>
|
||||||
|
<p><strong>Date:</strong> Mon, 7 Apr 2026 09:00:00 +0200</p>
|
||||||
|
</div>
|
||||||
|
<div class="body">
|
||||||
|
<p>Hi,</p>
|
||||||
|
<p>Please fix the login bug by Friday. It is blocking the release.</p>
|
||||||
|
<p>Priority: high. Let me know if you need anything.</p>
|
||||||
|
<p>Thanks,<br>Boss</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
49
tests/fixtures/preprocessors/data/email_heavy.html
vendored
Normal file
49
tests/fixtures/preprocessors/data/email_heavy.html
vendored
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<style>
|
||||||
|
table { border-collapse: collapse; width: 100%; max-width: 600px; margin: 0 auto; }
|
||||||
|
td { padding: 8px 12px; border: 1px solid #dddddd; font-size: 12px; color: #444444; }
|
||||||
|
.header-row { background-color: #003366; color: #ffffff; font-weight: bold; }
|
||||||
|
.label-col { background-color: #f0f0f0; width: 80px; font-weight: bold; }
|
||||||
|
.footer-row { font-size: 10px; color: #999999; text-align: center; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body bgcolor="#eeeeee">
|
||||||
|
<center>
|
||||||
|
<table cellpadding="0" cellspacing="0">
|
||||||
|
<tr class="header-row">
|
||||||
|
<td colspan="2">Company Internal Update</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td class="label-col">From:</td>
|
||||||
|
<td>newsletter@corp.com</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td class="label-col">Subject:</td>
|
||||||
|
<td>Q1 Results Update</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td class="label-col">Date:</td>
|
||||||
|
<td>Apr 7, 2026</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td colspan="2">
|
||||||
|
<table width="100%" cellpadding="10">
|
||||||
|
<tr>
|
||||||
|
<td>
|
||||||
|
<p style="font-size:14px; font-weight:bold;">Dear Team,</p>
|
||||||
|
<p>Q1 results are in. Revenue up 15% year-over-year.</p>
|
||||||
|
<p>Please review the attached report and share any feedback by EOW.</p>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr class="footer-row">
|
||||||
|
<td colspan="2">Confidential — do not forward outside the company.</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
</center>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
8
tests/fixtures/preprocessors/data/email_single.html
vendored
Normal file
8
tests/fixtures/preprocessors/data/email_single.html
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html><body>
|
||||||
|
<p><strong>From:</strong> alice@co.com</p>
|
||||||
|
<p><strong>To:</strong> team@co.com</p>
|
||||||
|
<p><strong>Subject:</strong> Quick update</p>
|
||||||
|
<p><strong>Date:</strong> Tue, 7 Apr 2026 10:30:00 +0200</p>
|
||||||
|
<p>The deploy is done. Everything looks good. No issues so far.</p>
|
||||||
|
</body></html>
|
||||||
24
tests/fixtures/preprocessors/data/email_thread.html
vendored
Normal file
24
tests/fixtures/preprocessors/data/email_thread.html
vendored
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html><body>
|
||||||
|
<div class="message-latest">
|
||||||
|
<p><strong>From:</strong> alice@co.com</p>
|
||||||
|
<p><strong>Subject:</strong> Re: Re: Deploy plan</p>
|
||||||
|
<p>Sure, I'll handle the deploy.</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p>On Mon, Apr 6, 2026 at 3:00 PM, Bob <bob@co.com> wrote:</p>
|
||||||
|
<blockquote>
|
||||||
|
<p>From: bob@co.com</p>
|
||||||
|
<p>Can you handle the deploy?</p>
|
||||||
|
<p>On Sun, Apr 5, 2026 at 1:00 PM, Alice <alice@co.com> wrote:</p>
|
||||||
|
<blockquote>
|
||||||
|
<p>From: alice@co.com</p>
|
||||||
|
<p>Let's plan the deploy for Monday.</p>
|
||||||
|
<p>On Sat, Apr 4, 2026 at 11:00 AM, Charlie <charlie@co.com> wrote:</p>
|
||||||
|
<blockquote>
|
||||||
|
<p>From: charlie@co.com</p>
|
||||||
|
<p>We need to schedule the deploy. What day works?</p>
|
||||||
|
</blockquote>
|
||||||
|
</blockquote>
|
||||||
|
</blockquote>
|
||||||
|
</body></html>
|
||||||
3
tests/fixtures/preprocessors/data/fallback.txt
vendored
Normal file
3
tests/fixtures/preprocessors/data/fallback.txt
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
random text content without any structure
|
||||||
|
line two with some words
|
||||||
|
line three and more content here
|
||||||
35
tests/fixtures/preprocessors/data/generic_page.html
vendored
Normal file
35
tests/fixtures/preprocessors/data/generic_page.html
vendored
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>My Web App</title>
|
||||||
|
<link rel="stylesheet" href="styles.css">
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<nav>
|
||||||
|
<a href="/">Home</a>
|
||||||
|
<a href="/about">About</a>
|
||||||
|
<a href="/contact">Contact</a>
|
||||||
|
</nav>
|
||||||
|
<main>
|
||||||
|
<header>
|
||||||
|
<h1>Welcome to My App</h1>
|
||||||
|
</header>
|
||||||
|
<article>
|
||||||
|
<p>This is a generic web page with no email headers.</p>
|
||||||
|
<p>It has navigation, main content, and a footer.</p>
|
||||||
|
</article>
|
||||||
|
<section>
|
||||||
|
<h2>Features</h2>
|
||||||
|
<ul>
|
||||||
|
<li>Fast</li>
|
||||||
|
<li>Reliable</li>
|
||||||
|
<li>Secure</li>
|
||||||
|
</ul>
|
||||||
|
</section>
|
||||||
|
</main>
|
||||||
|
<footer>
|
||||||
|
<p>© 2026 My App</p>
|
||||||
|
</footer>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
15
tests/fixtures/preprocessors/data/notes.txt
vendored
Normal file
15
tests/fixtures/preprocessors/data/notes.txt
vendored
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
Meeting notes - April 7, 2026
|
||||||
|
|
||||||
|
Attendees: Alice, Bob, Charlie
|
||||||
|
|
||||||
|
Discussion points:
|
||||||
|
- Deploy scheduled for Friday
|
||||||
|
- Bug fix for login must be completed by Thursday
|
||||||
|
- Review Q1 numbers before EOW
|
||||||
|
|
||||||
|
Action items:
|
||||||
|
- Alice: fix login bug
|
||||||
|
- Bob: prepare deploy checklist
|
||||||
|
- Charlie: send Q1 report
|
||||||
|
|
||||||
|
Next meeting: April 14, 2026
|
||||||
@@ -1,214 +0,0 @@
|
|||||||
"""Unit tests for the agent registry, base classes, and tool loop."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _StubAgent(ChatAgent):
|
|
||||||
"""Minimal concrete agent for testing."""
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "stub"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "A stub agent for tests"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return f"echo: {query}"
|
|
||||||
|
|
||||||
|
|
||||||
class _AnotherAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "another"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Another stub"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return "another"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _fresh_registry():
|
|
||||||
"""Reset the singleton between tests."""
|
|
||||||
AgentRegistry._instance = None
|
|
||||||
yield
|
|
||||||
AgentRegistry._instance = None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def reg() -> AgentRegistry:
|
|
||||||
return AgentRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tests ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestRegisterAndGet:
|
|
||||||
def test_register_decorator(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
agent = reg.get("stub")
|
|
||||||
assert isinstance(agent, _StubAgent)
|
|
||||||
|
|
||||||
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
|
|
||||||
with pytest.raises(KeyError, match="not found"):
|
|
||||||
reg.get("nonexistent")
|
|
||||||
|
|
||||||
def test_register_multiple(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
reg.register(_AnotherAgent)
|
|
||||||
assert reg.get("stub").get_name() == "stub"
|
|
||||||
assert reg.get("another").get_name() == "another"
|
|
||||||
|
|
||||||
|
|
||||||
class TestListAgents:
|
|
||||||
def test_empty(self, reg: AgentRegistry) -> None:
|
|
||||||
assert reg.list_agents() == []
|
|
||||||
|
|
||||||
def test_list_after_register(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
agents = reg.list_agents()
|
|
||||||
assert len(agents) == 1
|
|
||||||
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
|
|
||||||
|
|
||||||
def test_list_multiple(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
reg.register(_AnotherAgent)
|
|
||||||
names = {a["name"] for a in reg.list_agents()}
|
|
||||||
assert names == {"stub", "another"}
|
|
||||||
|
|
||||||
|
|
||||||
class TestCallAgent:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_agent(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
result = await reg.call_agent("stub", "hello", {})
|
|
||||||
assert result == "echo: hello"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
await reg.call_agent("nope", "hi", {})
|
|
||||||
|
|
||||||
|
|
||||||
class TestSingleton:
|
|
||||||
def test_singleton_identity(self) -> None:
|
|
||||||
a = AgentRegistry()
|
|
||||||
b = AgentRegistry()
|
|
||||||
assert a is b
|
|
||||||
|
|
||||||
|
|
||||||
class TestToolLoop:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_tool_calls(self) -> None:
|
|
||||||
"""When the LLM responds without tool calls, return content directly."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
ai_msg = MagicMock()
|
|
||||||
ai_msg.content = "final answer"
|
|
||||||
ai_msg.tool_calls = []
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=ai_msg)
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [])
|
|
||||||
assert result == "final answer"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_call_then_answer(self) -> None:
|
|
||||||
"""LLM requests one tool call, gets result, then answers."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
# First response: tool call
|
|
||||||
tool_call_msg = MagicMock()
|
|
||||||
tool_call_msg.content = ""
|
|
||||||
tool_call_msg.tool_calls = [
|
|
||||||
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
|
|
||||||
]
|
|
||||||
|
|
||||||
# Second response: final answer
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = "done"
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm)
|
|
||||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
|
||||||
|
|
||||||
# Mock tool
|
|
||||||
tool = AsyncMock()
|
|
||||||
tool.name = "my_tool"
|
|
||||||
tool.ainvoke = AsyncMock(return_value="tool_result")
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [tool])
|
|
||||||
assert result == "done"
|
|
||||||
tool.ainvoke.assert_called_once_with({"x": 1})
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unknown_tool_handled(self) -> None:
|
|
||||||
"""Unknown tool names produce an error message instead of crashing."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
tool_call_msg = MagicMock()
|
|
||||||
tool_call_msg.content = ""
|
|
||||||
tool_call_msg.tool_calls = [
|
|
||||||
{"id": "call_1", "name": "missing", "args": {}}
|
|
||||||
]
|
|
||||||
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = "recovered"
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm)
|
|
||||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [])
|
|
||||||
assert result == "recovered"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_max_iter_reached(self) -> None:
|
|
||||||
"""When max iterations are exhausted, a final no-tools call is made."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
# Every response requests a tool call
|
|
||||||
loop_msg = MagicMock()
|
|
||||||
loop_msg.content = ""
|
|
||||||
loop_msg.tool_calls = [
|
|
||||||
{"id": "call_x", "name": "t", "args": {}}
|
|
||||||
]
|
|
||||||
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = "gave up"
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
tool = AsyncMock()
|
|
||||||
tool.name = "t"
|
|
||||||
tool.ainvoke = AsyncMock(return_value="ok")
|
|
||||||
|
|
||||||
llm_with_tools = AsyncMock()
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
|
|
||||||
assert result == "gave up"
|
|
||||||
assert llm_with_tools.ainvoke.call_count == 2
|
|
||||||
@@ -1,871 +0,0 @@
|
|||||||
"""Tests for Step 3.4: agent_runner module.
|
|
||||||
|
|
||||||
Coverage:
|
|
||||||
Unit:
|
|
||||||
- _is_overdue — cron schedule overdue detection
|
|
||||||
- _extract_items_from_content — LLM extraction + JSON parsing + validation
|
|
||||||
- _send_insert_to_client — tool_call frame construction + timeout
|
|
||||||
- run_local_agent — end-to-end local agent happy path
|
|
||||||
- run_local_agent — device offline path
|
|
||||||
- run_local_agent — file-read timeout path
|
|
||||||
- run_local_agent — LLM extraction error path
|
|
||||||
- run_cloud_agent — stub returns error immediately
|
|
||||||
- trigger_pending_runs — overdue local + cloud dispatched
|
|
||||||
- trigger_pending_runs — non-overdue skipped
|
|
||||||
- trigger_pending_runs — device_id filter for local agents
|
|
||||||
|
|
||||||
Integration:
|
|
||||||
- POST /agents/{id}/run — 404 on unknown agent
|
|
||||||
- POST /agents/{id}/run — creates run log + dispatches background task
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from app.core.agent_runner import (
|
|
||||||
_extract_items_from_content,
|
|
||||||
_is_overdue,
|
|
||||||
_send_insert_to_client,
|
|
||||||
run_cloud_agent,
|
|
||||||
run_local_agent,
|
|
||||||
trigger_pending_runs,
|
|
||||||
)
|
|
||||||
from app.core.device_manager import DeviceConnectionManager
|
|
||||||
from app.db import get_session
|
|
||||||
from app.main import app
|
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
|
||||||
from tests.conftest import TEST_USER_IDS, auth_header
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
_FREE_UID = TEST_USER_IDS["free"]
|
|
||||||
_PRO_UID = TEST_USER_IDS["pro"]
|
|
||||||
|
|
||||||
|
|
||||||
def _make_local_config(user_id: str = _FREE_UID, device_id: str = "dev-001") -> LocalAgentConfig:
|
|
||||||
return LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
device_id=device_id,
|
|
||||||
name="Test Local Agent",
|
|
||||||
directory_paths=["/home/user/emails"],
|
|
||||||
data_types=["tasks", "notes"],
|
|
||||||
prompt_template="Extract tasks and notes from this document.",
|
|
||||||
file_extensions=[".txt", ".eml"],
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
last_run_at=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_cloud_config(user_id: str = _FREE_UID) -> CloudAgentConfig:
|
|
||||||
return CloudAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
provider="gmail",
|
|
||||||
name="Test Gmail Agent",
|
|
||||||
data_types=["tasks"],
|
|
||||||
prompt_template="Extract tasks from email.",
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
last_run_at=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_run_log(agent_id: str, agent_type: str = "local", user_id: str = _FREE_UID) -> AgentRunLog:
|
|
||||||
return AgentRunLog(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
agent_id=agent_id,
|
|
||||||
agent_type=agent_type,
|
|
||||||
user_id=user_id,
|
|
||||||
status="running",
|
|
||||||
started_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_manager(user_id: str = _FREE_UID, device_id: str = "dev-001") -> DeviceConnectionManager:
|
|
||||||
mgr = DeviceConnectionManager()
|
|
||||||
ws = MagicMock()
|
|
||||||
ws.send_text = AsyncMock()
|
|
||||||
mgr.register(user_id, device_id, ws)
|
|
||||||
return mgr
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _is_overdue
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def test_is_overdue_never_run():
|
|
||||||
"""An agent that has never run is always overdue."""
|
|
||||||
assert _is_overdue("0 */6 * * *", None) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_overdue_very_recently_run():
|
|
||||||
"""An agent that just ran is not overdue."""
|
|
||||||
last = datetime.now(timezone.utc)
|
|
||||||
assert _is_overdue("0 */6 * * *", last) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_overdue_long_ago():
|
|
||||||
"""An agent last run 2 days ago with a 6-hour schedule is overdue."""
|
|
||||||
from datetime import timedelta
|
|
||||||
last = datetime.now(timezone.utc) - timedelta(days=2)
|
|
||||||
assert _is_overdue("0 */6 * * *", last) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_overdue_invalid_cron_returns_false():
|
|
||||||
"""Unparseable cron must not raise and should return False (fail-safe)."""
|
|
||||||
assert _is_overdue("not a cron", None) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_overdue_naive_datetime():
|
|
||||||
"""Naive datetime objects are handled without raising."""
|
|
||||||
from datetime import timedelta
|
|
||||||
last = datetime.utcnow() - timedelta(days=1) # naive
|
|
||||||
# Should not raise.
|
|
||||||
result = _is_overdue("0 */6 * * *", last)
|
|
||||||
assert isinstance(result, bool)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _extract_items_from_content
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_items_happy_path():
|
|
||||||
"""LLM returns valid JSON array; items with allowed tables are returned."""
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.content = json.dumps([
|
|
||||||
{"table": "tasks", "data": {"title": "Buy milk", "priority": "high"}},
|
|
||||||
{"table": "notes", "data": {"title": "Meeting recap", "content": "Discussed roadmap"}},
|
|
||||||
])
|
|
||||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
|
||||||
items = await _extract_items_from_content(
|
|
||||||
"Extract tasks and notes.",
|
|
||||||
"Email body: Buy milk urgently. Notes from meeting: discussed roadmap.",
|
|
||||||
["tasks", "notes"],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(items) == 2
|
|
||||||
assert items[0]["table"] == "tasks"
|
|
||||||
assert items[0]["data"]["title"] == "Buy milk"
|
|
||||||
assert items[1]["table"] == "notes"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_items_strips_forbidden_fields():
|
|
||||||
"""Fields like id, createdAt, isAiSuggested must be stripped from extracted data."""
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.content = json.dumps([
|
|
||||||
{
|
|
||||||
"table": "tasks",
|
|
||||||
"data": {
|
|
||||||
"title": "Review PR",
|
|
||||||
"id": "should-be-removed",
|
|
||||||
"createdAt": 99999,
|
|
||||||
"isAiSuggested": 0,
|
|
||||||
"isApproved": 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
])
|
|
||||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
|
||||||
items = await _extract_items_from_content("Extract tasks.", "Review the PR.", ["tasks"])
|
|
||||||
|
|
||||||
assert len(items) == 1
|
|
||||||
data = items[0]["data"]
|
|
||||||
assert "id" not in data
|
|
||||||
assert "createdAt" not in data
|
|
||||||
assert "isAiSuggested" not in data
|
|
||||||
assert "isApproved" not in data
|
|
||||||
assert data["title"] == "Review PR"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_items_invalid_json_returns_empty():
|
|
||||||
"""LLM returning invalid JSON must return empty list without raising."""
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.content = "Sorry, I cannot extract anything."
|
|
||||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
|
||||||
items = await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
|
||||||
|
|
||||||
assert items == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_items_disallowed_table_filtered():
|
|
||||||
"""Items whose table is not in data_types are discarded."""
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.content = json.dumps([
|
|
||||||
{"table": "tasks", "data": {"title": "Valid task"}},
|
|
||||||
{"table": "projects", "data": {"name": "Should be filtered"}},
|
|
||||||
])
|
|
||||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
|
||||||
# Only "tasks" is in data_types — "projects" should be filtered.
|
|
||||||
items = await _extract_items_from_content("Extract.", "content", ["tasks"])
|
|
||||||
|
|
||||||
assert len(items) == 1
|
|
||||||
assert items[0]["table"] == "tasks"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_items_empty_data_types_returns_empty():
|
|
||||||
"""If no allowed data_types match, skip LLM call and return immediately."""
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_llm.ainvoke = AsyncMock()
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
|
||||||
items = await _extract_items_from_content("Extract.", "content", [])
|
|
||||||
|
|
||||||
mock_llm.ainvoke.assert_not_called()
|
|
||||||
assert items == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_items_llm_error_propagates():
|
|
||||||
"""LLM API errors propagate so the caller (run_local_agent) can record them."""
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("API unavailable"))
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
|
||||||
with pytest.raises(RuntimeError, match="API unavailable"):
|
|
||||||
await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _send_insert_to_client
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_send_insert_to_client_happy_path():
|
|
||||||
"""Frame is sent with isAiSuggested/isApproved added; result is returned."""
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
sent_payloads: list[dict] = []
|
|
||||||
original_send = mgr.send_frame
|
|
||||||
|
|
||||||
async def _capture_send(uid: str, frame: dict) -> None:
|
|
||||||
sent_payloads.append(frame)
|
|
||||||
# Immediately resolve the pending call with a success result.
|
|
||||||
call_id = frame["id"]
|
|
||||||
mgr.resolve_pending_call(uid, call_id, {"row": {"id": "new-id", "title": "Buy milk"}})
|
|
||||||
|
|
||||||
mgr.send_frame = _capture_send # type: ignore[method-assign]
|
|
||||||
|
|
||||||
result = await _send_insert_to_client(
|
|
||||||
_FREE_UID, "tasks", {"title": "Buy milk", "priority": "high"}, mgr
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(sent_payloads) == 1
|
|
||||||
payload = sent_payloads[0]
|
|
||||||
assert payload["action"] == "insert"
|
|
||||||
assert payload["table"] == "tasks"
|
|
||||||
assert payload["data"]["title"] == "Buy milk"
|
|
||||||
assert payload["data"]["isAiSuggested"] == 1
|
|
||||||
assert payload["data"]["isApproved"] == 0
|
|
||||||
assert result["row"]["title"] == "Buy milk"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_send_insert_to_client_timeout():
|
|
||||||
"""asyncio.TimeoutError is raised when Electron does not respond."""
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
async def _slow_send(uid: str, frame: dict) -> None:
|
|
||||||
# Never resolve the pending call.
|
|
||||||
pass
|
|
||||||
|
|
||||||
mgr.send_frame = _slow_send # type: ignore[method-assign]
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._INSERT_TIMEOUT", 0.05):
|
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
|
||||||
await _send_insert_to_client(_FREE_UID, "tasks", {"title": "X"}, mgr)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# run_local_agent
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_local_agent_device_offline():
|
|
||||||
"""run_local_agent marks run as error when device is offline."""
|
|
||||||
config = _make_local_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = DeviceConnectionManager() # Empty — no device registered.
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
|
||||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
mock_finalize.assert_called_once()
|
|
||||||
_args, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("not connected" in e for e in kwargs["errors"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_local_agent_happy_path():
|
|
||||||
"""End-to-end: files received, LLM extracts one task, insert sent + ack'd."""
|
|
||||||
config = _make_local_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
# Build a fake agent_data frame (will be queued after send).
|
|
||||||
file_frame = {
|
|
||||||
"type": "agent_data",
|
|
||||||
"run_id": run_log.id,
|
|
||||||
"files": [{"path": "/email.eml", "content": "Urgent: fix the bug by Friday."}],
|
|
||||||
}
|
|
||||||
agent_complete_frame = None # sentinel
|
|
||||||
|
|
||||||
sent_frames: list[dict] = []
|
|
||||||
|
|
||||||
async def _mock_send(uid: str, frame: dict) -> None:
|
|
||||||
sent_frames.append(frame)
|
|
||||||
if frame.get("type") == "agent_run":
|
|
||||||
# Simulate Electron responding with file data then agent_complete.
|
|
||||||
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
|
||||||
await q.put(file_frame)
|
|
||||||
await q.put(agent_complete_frame)
|
|
||||||
elif frame.get("type") == "tool_call":
|
|
||||||
# Resolve the pending insert immediately.
|
|
||||||
mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-task", "title": "Fix the bug"}})
|
|
||||||
|
|
||||||
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
|
||||||
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.content = json.dumps([
|
|
||||||
{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}
|
|
||||||
])
|
|
||||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
|
||||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
mock_finalize.assert_called_once()
|
|
||||||
_args, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "success"
|
|
||||||
assert kwargs["items_processed"] == 1
|
|
||||||
assert kwargs["items_created"] == 1
|
|
||||||
assert kwargs["errors"] == []
|
|
||||||
assert kwargs["update_config_last_run"] is True
|
|
||||||
|
|
||||||
# Verify agent_run frame was sent.
|
|
||||||
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
|
||||||
assert len(agent_run_frames) == 1
|
|
||||||
assert agent_run_frames[0]["agent_id"] == config.id
|
|
||||||
assert "paths" in agent_run_frames[0]["config"]
|
|
||||||
|
|
||||||
# Verify insert frame was sent with AI flags.
|
|
||||||
insert_frames = [f for f in sent_frames if f.get("type") == "tool_call"]
|
|
||||||
assert len(insert_frames) == 1
|
|
||||||
assert insert_frames[0]["data"]["isAiSuggested"] == 1
|
|
||||||
assert insert_frames[0]["data"]["isApproved"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_local_agent_file_read_timeout():
|
|
||||||
"""run_local_agent marks run as partial/error when device stops sending files."""
|
|
||||||
config = _make_local_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
async def _mock_send(uid: str, frame: dict) -> None:
|
|
||||||
# Don't put anything in the queue — simulate stalled device.
|
|
||||||
pass
|
|
||||||
|
|
||||||
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._FILE_READ_TIMEOUT", 0.1), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
|
||||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
mock_finalize.assert_called_once()
|
|
||||||
_args, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error" # No items created, so error (not partial).
|
|
||||||
assert any("timed out" in e.lower() for e in kwargs["errors"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_local_agent_llm_extraction_error():
|
|
||||||
"""LLM errors per-file are recorded; run continues for remaining files."""
|
|
||||||
config = _make_local_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
file_frame = {
|
|
||||||
"type": "agent_data",
|
|
||||||
"run_id": run_log.id,
|
|
||||||
"files": [
|
|
||||||
{"path": "/file1.eml", "content": "Email one."},
|
|
||||||
{"path": "/file2.eml", "content": "Email two."},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _mock_send(uid: str, frame: dict) -> None:
|
|
||||||
if frame.get("type") == "agent_run":
|
|
||||||
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
|
||||||
await q.put(file_frame)
|
|
||||||
await q.put(None) # agent_complete sentinel
|
|
||||||
|
|
||||||
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
|
||||||
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM boom"))
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
|
||||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_args, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert kwargs["items_processed"] == 2 # Both files attempted.
|
|
||||||
assert kwargs["items_created"] == 0
|
|
||||||
assert len(kwargs["errors"]) == 2 # One error per file.
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# run_cloud_agent (stub)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_cloud_agent_device_offline():
|
|
||||||
"""Cloud agent aborts immediately when no device is connected."""
|
|
||||||
config = _make_cloud_config()
|
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
|
||||||
mgr = DeviceConnectionManager() # empty — no devices registered
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
mock_finalize.assert_called_once()
|
|
||||||
_, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("device" in e.lower() or "connected" in e.lower() for e in kwargs["errors"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_cloud_agent_no_oauth_token():
|
|
||||||
"""Cloud agent errors when no OAuth token is stored."""
|
|
||||||
config = _make_cloud_config()
|
|
||||||
config.oauth_token_encrypted = None
|
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_cloud_agent_token_decrypt_failure():
|
|
||||||
"""Cloud agent errors gracefully when the stored token cannot be decrypted."""
|
|
||||||
config = _make_cloud_config()
|
|
||||||
config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
|
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
from cryptography.fernet import Fernet as _Fernet
|
|
||||||
valid_key = _Fernet.generate_key().decode()
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
|
||||||
patch("app.integrations.settings") as mock_settings:
|
|
||||||
mock_settings.OAUTH_ENCRYPTION_KEY = valid_key
|
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("decrypt" in e.lower() for e in kwargs["errors"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_cloud_agent_happy_path_gmail():
|
|
||||||
"""Cloud agent happy path: Gmail fetch → LLM extraction → inserts → success."""
|
|
||||||
from app.integrations import EmailMessage, encrypt_token
|
|
||||||
from cryptography.fernet import Fernet as _Fernet
|
|
||||||
|
|
||||||
fernet_key = _Fernet.generate_key().decode()
|
|
||||||
credentials = {
|
|
||||||
"token": "access_abc",
|
|
||||||
"refresh_token": "refresh_xyz",
|
|
||||||
"token_uri": "https://oauth2.googleapis.com/token",
|
|
||||||
"client_id": "cid",
|
|
||||||
"client_secret": "csec",
|
|
||||||
}
|
|
||||||
|
|
||||||
config = _make_cloud_config()
|
|
||||||
config.provider = "gmail"
|
|
||||||
config.prompt_template = "Extract tasks from this email."
|
|
||||||
config.data_types = ["tasks"]
|
|
||||||
|
|
||||||
with patch("app.integrations.settings") as ms:
|
|
||||||
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
|
||||||
config.oauth_token_encrypted = encrypt_token(credentials)
|
|
||||||
|
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
sample_email = EmailMessage(
|
|
||||||
id="msg001",
|
|
||||||
subject="Action required",
|
|
||||||
sender="boss@company.com",
|
|
||||||
body_text="Please fix the bug by Friday.",
|
|
||||||
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]
|
|
||||||
|
|
||||||
with patch("app.integrations.settings") as mock_int_settings, \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
|
||||||
patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \
|
|
||||||
patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \
|
|
||||||
patch("app.core.agent_runner.async_session"):
|
|
||||||
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
|
||||||
|
|
||||||
mock_gmail = AsyncMock()
|
|
||||||
mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email])
|
|
||||||
mock_gmail.refreshed_credentials = None
|
|
||||||
|
|
||||||
with patch("app.integrations.decrypt_token", return_value=credentials), \
|
|
||||||
patch("app.integrations.get_provider", return_value=mock_gmail):
|
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
mock_extract.assert_called_once()
|
|
||||||
mock_insert.assert_called_once()
|
|
||||||
_, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "success"
|
|
||||||
assert kwargs["items_processed"] == 1
|
|
||||||
assert kwargs["items_created"] == 1
|
|
||||||
assert kwargs["config_type"] == "cloud"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_cloud_agent_provider_fetch_error():
|
|
||||||
"""Cloud agent records error status when provider fetch raises RuntimeError."""
|
|
||||||
credentials = {"token": "abc"}
|
|
||||||
config = _make_cloud_config()
|
|
||||||
config.oauth_token_encrypted = "some_encrypted_value" # non-empty so decrypt step is reached
|
|
||||||
config.prompt_template = "Extract tasks."
|
|
||||||
config.data_types = ["tasks"]
|
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
mock_provider = AsyncMock()
|
|
||||||
mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
|
|
||||||
mock_provider.refreshed_credentials = None
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
|
||||||
patch("app.integrations.decrypt_token", return_value=credentials), \
|
|
||||||
patch("app.integrations.get_provider", return_value=mock_provider), \
|
|
||||||
patch("app.core.agent_runner.async_session"):
|
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_cloud_agent_refreshed_token_persisted():
|
|
||||||
"""When the provider refreshes its token, the new ciphertext is written to DB."""
|
|
||||||
from app.integrations import EmailMessage, encrypt_token
|
|
||||||
from cryptography.fernet import Fernet as _Fernet
|
|
||||||
|
|
||||||
fernet_key = _Fernet.generate_key().decode()
|
|
||||||
credentials = {"token": "old_token", "refresh_token": "rt_old"}
|
|
||||||
fresh_credentials = {"token": "new_token", "refresh_token": "rt_new"}
|
|
||||||
|
|
||||||
config = _make_cloud_config()
|
|
||||||
config.prompt_template = "Extract tasks."
|
|
||||||
config.data_types = ["tasks"]
|
|
||||||
|
|
||||||
with patch("app.integrations.settings") as ms:
|
|
||||||
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
|
||||||
config.oauth_token_encrypted = encrypt_token(credentials)
|
|
||||||
|
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
mock_provider = AsyncMock()
|
|
||||||
mock_provider.fetch_messages = AsyncMock(return_value=[])
|
|
||||||
mock_provider.refreshed_credentials = fresh_credentials # token was refreshed
|
|
||||||
|
|
||||||
# Track DB writes via mock async_session.
|
|
||||||
mock_cfg_row = MagicMock()
|
|
||||||
mock_cfg_row.oauth_token_encrypted = None
|
|
||||||
|
|
||||||
mock_db = AsyncMock()
|
|
||||||
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
|
||||||
mock_db.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_db.scalar_one_or_none = AsyncMock(return_value=mock_cfg_row)
|
|
||||||
cfg_result = MagicMock()
|
|
||||||
cfg_result.scalar_one_or_none.return_value = mock_cfg_row
|
|
||||||
mock_db.execute = AsyncMock(return_value=cfg_result)
|
|
||||||
mock_db.commit = AsyncMock()
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \
|
|
||||||
patch("app.integrations.decrypt_token", return_value=credentials), \
|
|
||||||
patch("app.integrations.get_provider", return_value=mock_provider), \
|
|
||||||
patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \
|
|
||||||
patch("app.core.agent_runner.async_session", return_value=mock_db), \
|
|
||||||
patch("app.integrations.settings") as mock_int_settings:
|
|
||||||
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
|
||||||
|
|
||||||
# The new encrypted token should have been written to the config row.
|
|
||||||
mock_encrypt.assert_called_once_with(fresh_credentials)
|
|
||||||
assert mock_cfg_row.oauth_token_encrypted == "new_encrypted"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_finalize_run_updates_cloud_config_last_run_at():
|
|
||||||
"""_finalize_run with config_type='cloud' updates CloudAgentConfig.last_run_at."""
|
|
||||||
from app.core.agent_runner import _finalize_run
|
|
||||||
|
|
||||||
run_log = _make_run_log(str(uuid.uuid4()), agent_type="cloud")
|
|
||||||
run_log.id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
mock_cfg = MagicMock()
|
|
||||||
mock_cfg.last_run_at = None
|
|
||||||
|
|
||||||
cfg_result = MagicMock()
|
|
||||||
cfg_result.scalar_one_or_none.return_value = mock_cfg
|
|
||||||
|
|
||||||
mock_db = AsyncMock()
|
|
||||||
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
|
||||||
mock_db.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_db.merge = AsyncMock(return_value=run_log)
|
|
||||||
mock_db.execute = AsyncMock(return_value=cfg_result)
|
|
||||||
mock_db.commit = AsyncMock()
|
|
||||||
|
|
||||||
config_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session", return_value=mock_db):
|
|
||||||
await _finalize_run(
|
|
||||||
run_log,
|
|
||||||
status="success",
|
|
||||||
update_config_last_run=True,
|
|
||||||
config_id=config_id,
|
|
||||||
config_type="cloud",
|
|
||||||
)
|
|
||||||
|
|
||||||
# CloudAgentConfig.last_run_at should have been set.
|
|
||||||
assert mock_cfg.last_run_at is not None
|
|
||||||
mock_db.commit.assert_called()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# trigger_pending_runs
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trigger_pending_runs_no_overdue():
|
|
||||||
"""If no agents are overdue trigger_pending_runs does nothing."""
|
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
config = _make_local_config()
|
|
||||||
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
|
|
||||||
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
|
|
||||||
|
|
||||||
mock_db_result_local = MagicMock()
|
|
||||||
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
|
||||||
|
|
||||||
mock_db_result_cloud = MagicMock()
|
|
||||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
|
||||||
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
|
||||||
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
|
||||||
mock_ctx = AsyncMock()
|
|
||||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
|
||||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_ctx.execute = AsyncMock(
|
|
||||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
|
||||||
)
|
|
||||||
mock_session_factory.return_value = mock_ctx
|
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trigger_pending_runs_device_id_filter():
|
|
||||||
"""Local agents are only triggered for the matching device_id."""
|
|
||||||
# The DB query already filters by device_id, so we verify the SELECT
|
|
||||||
# includes the device_id filter by checking that a config bound to a
|
|
||||||
# different device is never dispatched.
|
|
||||||
#
|
|
||||||
# Since trigger_pending_runs queries with device_id == "dev-001",
|
|
||||||
# simulate the DB returning an empty list (as it would for a mismatch).
|
|
||||||
mock_db_result_local = MagicMock()
|
|
||||||
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
|
|
||||||
|
|
||||||
mock_db_result_cloud = MagicMock()
|
|
||||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
|
||||||
|
|
||||||
mgr = _make_manager(device_id="dev-001")
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
|
||||||
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
|
||||||
mock_ctx = AsyncMock()
|
|
||||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
|
||||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_ctx.execute = AsyncMock(
|
|
||||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
|
||||||
)
|
|
||||||
mock_session_factory.return_value = mock_ctx
|
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trigger_pending_runs_dispatches_overdue():
|
|
||||||
"""Overdue local agent triggers run_local_agent sequentially."""
|
|
||||||
config = _make_local_config() # last_run_at=None → always overdue
|
|
||||||
|
|
||||||
mock_db_result_local = MagicMock()
|
|
||||||
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
|
||||||
|
|
||||||
mock_db_result_cloud = MagicMock()
|
|
||||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
|
||||||
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
call_order: list[str] = []
|
|
||||||
|
|
||||||
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
|
|
||||||
call_order.append("run_local")
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
|
||||||
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
|
|
||||||
# First call: query configs. Subsequent calls: create run_log.
|
|
||||||
mock_query_ctx = AsyncMock()
|
|
||||||
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
|
|
||||||
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_query_ctx.execute = AsyncMock(
|
|
||||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
|
||||||
)
|
|
||||||
|
|
||||||
run_log_obj = AgentRunLog(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
agent_id=config.id,
|
|
||||||
agent_type="local",
|
|
||||||
user_id=_FREE_UID,
|
|
||||||
status="running",
|
|
||||||
started_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
mock_insert_ctx = AsyncMock()
|
|
||||||
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
|
|
||||||
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_insert_ctx.add = MagicMock()
|
|
||||||
mock_insert_ctx.commit = AsyncMock()
|
|
||||||
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
|
|
||||||
|
|
||||||
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
|
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
|
||||||
|
|
||||||
assert call_order == ["run_local"]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Integration: POST /agents/{id}/run
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _override_db(db_session):
|
|
||||||
"""Route all get_session calls to the test SQLite session."""
|
|
||||||
|
|
||||||
async def _gen():
|
|
||||||
yield db_session
|
|
||||||
|
|
||||||
app.dependency_overrides[get_session] = _gen
|
|
||||||
yield
|
|
||||||
app.dependency_overrides.pop(get_session, None)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trigger_run_unknown_agent(client):
|
|
||||||
"""POST /agents/{id}/run returns 404 for unknown agent id."""
|
|
||||||
resp = client.post(
|
|
||||||
f"/api/v1/agents/{uuid.uuid4()}/run",
|
|
||||||
headers=auth_header("power"),
|
|
||||||
)
|
|
||||||
assert resp.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
|
||||||
"""POST /agents/{id}/run creates a run log and dispatches a background task."""
|
|
||||||
# Create the local agent config in the DB.
|
|
||||||
config = LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=TEST_USER_IDS["power"],
|
|
||||||
device_id="dev-001",
|
|
||||||
name="My Agent",
|
|
||||||
directory_paths=["/home/user/docs"],
|
|
||||||
data_types=["tasks"],
|
|
||||||
prompt_template="Extract tasks.",
|
|
||||||
file_extensions=[".txt"],
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
)
|
|
||||||
db_session.add(config)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
dispatched: list = []
|
|
||||||
|
|
||||||
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
|
||||||
dispatched.append((user_id, cfg.id))
|
|
||||||
|
|
||||||
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
|
||||||
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
|
|
||||||
patch("asyncio.create_task") as mock_create_task:
|
|
||||||
resp = client.post(
|
|
||||||
f"/api/v1/agents/{config.id}/run",
|
|
||||||
headers=auth_header("power"),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert resp.status_code == 202
|
|
||||||
data = resp.json()
|
|
||||||
assert data["agent_id"] == config.id
|
|
||||||
assert data["status"] == "running"
|
|
||||||
assert data["agent_type"] == "local"
|
|
||||||
|
|
||||||
# Verify create_task was called (dispatching background run).
|
|
||||||
mock_create_task.assert_called_once()
|
|
||||||
430
tests/test_agent_runner_v2.py
Normal file
430
tests/test_agent_runner_v2.py
Normal file
@@ -0,0 +1,430 @@
|
|||||||
|
"""Tests for Local Agent V2 runner (Step 2).
|
||||||
|
|
||||||
|
Covers the unified per-file flow:
|
||||||
|
Phase A — detect + preprocess (Python, zero LLM)
|
||||||
|
Phase B — single LLM call with tools (classify + extract + create)
|
||||||
|
|
||||||
|
Fixture-based eval tests (2.1–2.7)
|
||||||
|
-----------------------------------
|
||||||
|
Cases are defined in tests/fixtures/agent_runner_v2/cases.yaml.
|
||||||
|
Email HTML files live in tests/fixtures/agent_runner_v2/data/.
|
||||||
|
Use --runner-dir to point at a custom folder (same structure required).
|
||||||
|
|
||||||
|
Unit tests (no LLM)
|
||||||
|
--------------------
|
||||||
|
2.8 items_created count → items_created == N create_* calls
|
||||||
|
2.9 Device offline → status=error
|
||||||
|
2.10 Empty file → items_processed=0, status=success
|
||||||
|
|
||||||
|
Run:
|
||||||
|
pytest tests/test_agent_runner_v2.py -v
|
||||||
|
pytest tests/test_agent_runner_v2.py -v -k "2_9 or 2_10 or 2_8" # unit only
|
||||||
|
pytest tests/test_agent_runner_v2.py -v -k "eval" # LLM evals only
|
||||||
|
pytest tests/test_agent_runner_v2.py -v --runner-dir /path/to/dir # custom fixtures
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from app.core.agent_runner import (
|
||||||
|
_format_metadata,
|
||||||
|
_format_projects,
|
||||||
|
_get_extraction_rules,
|
||||||
|
_get_no_match_behavior,
|
||||||
|
run_local_agent,
|
||||||
|
)
|
||||||
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
|
from app.core.langfuse_client import get_langfuse
|
||||||
|
from app.models import AgentRunLog, LocalAgentConfig
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
# ── Constants ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_USER_ID = TEST_USER_IDS["power"]
|
||||||
|
|
||||||
|
_DEFAULT_FIXTURE_DIR = Path(__file__).parent / "fixtures" / "agent_runner_v2"
|
||||||
|
|
||||||
|
_AGENT_CONFIG = {
|
||||||
|
"content_types": [
|
||||||
|
{
|
||||||
|
"id": "email_html",
|
||||||
|
"label": "Email HTML",
|
||||||
|
"detection_hint": "HTML file with From/To/Subject headers",
|
||||||
|
"preprocessing": "email_html",
|
||||||
|
"extraction_prompt": (
|
||||||
|
"If the email contains a direct action request or task assignment → create a task. "
|
||||||
|
"If the email contains informational content, updates, or FYI → create a note. "
|
||||||
|
"If the email mentions a specific date for a meeting or deadline → create a timeline entry."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"global_rules": [
|
||||||
|
"Se il file non è riconducibile a nessun progetto, non creare alcuna entità."
|
||||||
|
],
|
||||||
|
"data_types": ["tasks", "notes", "timelines"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Canonical project definitions, referenced symbolically in cases.yaml.
|
||||||
|
_PROJECTS: dict[str, dict] = {
|
||||||
|
"alpha": {"id": "proj-alpha", "name": "Project Alpha", "status": "active"},
|
||||||
|
"beta": {"id": "proj-beta", "name": "Project Beta", "status": "active"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixture loading ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _fixtures_dir(config) -> Path:
|
||||||
|
override = config.getoption("--runner-dir")
|
||||||
|
return Path(override) if override else _DEFAULT_FIXTURE_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def _load_cases(config) -> list[dict]:
|
||||||
|
return yaml.safe_load(
|
||||||
|
(_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_case_file(case: dict, data_dir: Path) -> str:
|
||||||
|
return (data_dir / case["file"]).read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_projects(entries: list[str | dict]) -> list[dict]:
|
||||||
|
"""Resolve project list from YAML: symbolic names and/or inline dicts."""
|
||||||
|
result = []
|
||||||
|
for entry in entries:
|
||||||
|
if isinstance(entry, str):
|
||||||
|
if entry in _PROJECTS:
|
||||||
|
result.append(_PROJECTS[entry])
|
||||||
|
elif isinstance(entry, dict):
|
||||||
|
result.append(entry)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ── pytest_generate_tests — parametrize eval tests from YAML ─────────────
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_generate_tests(metafunc):
|
||||||
|
if "runner_case" not in metafunc.fixturenames:
|
||||||
|
return
|
||||||
|
cases = _load_cases(metafunc.config)
|
||||||
|
metafunc.parametrize("runner_case", cases, ids=[c["id"] for c in cases])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_config(
|
||||||
|
agent_config: dict | None = None,
|
||||||
|
directory: str = "/emails",
|
||||||
|
device_id: str = "dev-001",
|
||||||
|
) -> LocalAgentConfig:
|
||||||
|
return LocalAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=_USER_ID,
|
||||||
|
device_id=device_id,
|
||||||
|
name="Test V2 Agent",
|
||||||
|
directory_paths=[directory],
|
||||||
|
data_types=["tasks", "notes", "timelines"],
|
||||||
|
prompt_template="",
|
||||||
|
agent_config=agent_config or _AGENT_CONFIG,
|
||||||
|
file_extensions=[".html", ".eml"],
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
last_run_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_run_log(agent_id: str) -> AgentRunLog:
|
||||||
|
return AgentRunLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=_USER_ID,
|
||||||
|
status="running",
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_manager(online: bool = True) -> DeviceConnectionManager:
|
||||||
|
mgr = DeviceConnectionManager()
|
||||||
|
if online:
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text = AsyncMock()
|
||||||
|
mgr.register(_USER_ID, "dev-001", ws)
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
|
||||||
|
def _make_executor(
|
||||||
|
file_path: str,
|
||||||
|
file_content: str,
|
||||||
|
projects: list[dict] | None = None,
|
||||||
|
existing_tasks: list[dict] | None = None,
|
||||||
|
existing_notes: list[dict] | None = None,
|
||||||
|
existing_timelines: list[dict] | None = None,
|
||||||
|
) -> tuple[Any, list[dict]]:
|
||||||
|
"""Return (async_executor, captured_calls).
|
||||||
|
|
||||||
|
The executor handles all ``execute_on_client`` payloads:
|
||||||
|
directory listing, file reading, project/entity fetching, and CRUD.
|
||||||
|
"""
|
||||||
|
calls: list[dict] = []
|
||||||
|
_projects = projects if projects is not None else list(_PROJECTS.values())
|
||||||
|
|
||||||
|
async def _executor(payload: dict) -> dict:
|
||||||
|
action = payload.get("action", "")
|
||||||
|
table = payload.get("table", "")
|
||||||
|
data = payload.get("data") or {}
|
||||||
|
calls.append({"action": action, "table": table, "data": data})
|
||||||
|
|
||||||
|
if action == "list_directory":
|
||||||
|
return {"entries": [{"type": "file", "path": file_path}]}
|
||||||
|
|
||||||
|
if action == "get_file_metadata":
|
||||||
|
return {"modifiedAt": None}
|
||||||
|
|
||||||
|
if action == "read_file_content":
|
||||||
|
return {"content": file_content}
|
||||||
|
|
||||||
|
if action == "select":
|
||||||
|
if table == "projects":
|
||||||
|
return {"rows": _projects}
|
||||||
|
if table == "tasks":
|
||||||
|
return {"rows": existing_tasks or []}
|
||||||
|
if table == "notes":
|
||||||
|
return {"rows": existing_notes or []}
|
||||||
|
if table == "timelines":
|
||||||
|
return {"rows": existing_timelines or []}
|
||||||
|
return {"rows": []}
|
||||||
|
|
||||||
|
if action == "insert":
|
||||||
|
return {"row": {"id": str(uuid.uuid4()), **data}}
|
||||||
|
|
||||||
|
if action == "update":
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return _executor, calls
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: helper functions ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_projects_empty():
|
||||||
|
assert "(no projects" in _format_projects([])
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_projects_with_data():
|
||||||
|
result = _format_projects([_PROJECTS["alpha"]])
|
||||||
|
assert "proj-alpha" in result
|
||||||
|
assert "Project Alpha" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_metadata_empty():
|
||||||
|
assert _format_metadata({}) == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_metadata_email():
|
||||||
|
meta = {"subject": "Fix bug", "from": "boss@co.com", "date": "2026-04-07"}
|
||||||
|
result = _format_metadata(meta)
|
||||||
|
assert "Fix bug" in result
|
||||||
|
assert "boss@co.com" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_extraction_rules_match():
|
||||||
|
rules = _get_extraction_rules(_AGENT_CONFIG, "email_html")
|
||||||
|
assert "task" in rules.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_extraction_rules_fallback():
|
||||||
|
rules = _get_extraction_rules(_AGENT_CONFIG, "plain_text")
|
||||||
|
assert "extract" in rules.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_no_match_behavior_from_global_rules():
|
||||||
|
behavior = _get_no_match_behavior(_AGENT_CONFIG)
|
||||||
|
assert behavior # non-empty
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_no_match_behavior_default():
|
||||||
|
behavior = _get_no_match_behavior({})
|
||||||
|
assert "project" in behavior.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: 2.9 — device offline ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_2_9_device_offline():
|
||||||
|
"""2.9 No device online → status=error, no executor created."""
|
||||||
|
config = _make_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager(online=False)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
||||||
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_fin.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("not connected" in e for e in kwargs.get("errors", []))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: 2.10 — empty file ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_2_10_empty_file():
|
||||||
|
"""2.10 File with empty content → skipped, items_processed=0, success."""
|
||||||
|
config = _make_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
executor, calls = _make_executor(
|
||||||
|
file_path="/emails/empty.html",
|
||||||
|
file_content="",
|
||||||
|
projects=[_PROJECTS["alpha"]],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
||||||
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_fin.call_args
|
||||||
|
assert kwargs["items_processed"] == 0
|
||||||
|
assert kwargs["status"] == "success"
|
||||||
|
assert kwargs["items_created"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: 2.8 — items_created count ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_2_8_items_created_count():
|
||||||
|
"""2.8 items_created == number of create_* tool calls per run."""
|
||||||
|
config = _make_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
executor, _calls = _make_executor(
|
||||||
|
file_path="/emails/action.html",
|
||||||
|
file_content="<html><body><p>Fix the login bug in Project Alpha.</p></body></html>",
|
||||||
|
projects=[_PROJECTS["alpha"]],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_run_agent(*, _tool_calls_out=None, **kw) -> str:
|
||||||
|
if _tool_calls_out is not None:
|
||||||
|
_tool_calls_out.extend(["create_task", "create_note", "update_task"])
|
||||||
|
return "Done."
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
||||||
|
patch("app.core.agent_runner._run_agent_with_tools", side_effect=mock_run_agent), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
||||||
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_fin.call_args
|
||||||
|
# Only create_task + create_note count (not update_task).
|
||||||
|
assert kwargs["items_created"] == 2
|
||||||
|
assert kwargs["items_processed"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── Eval: 2.1–2.7 — fixture-driven, real LLM + Langfuse scoring ──────────
|
||||||
|
#
|
||||||
|
# Cases loaded from tests/fixtures/agent_runner_v2/cases.yaml.
|
||||||
|
# Supported assertions (from YAML):
|
||||||
|
# expect_insert: <table> → at least 1 insert in that table
|
||||||
|
# expect_no_insert: true → zero inserts in any table
|
||||||
|
# expect_project_id: <id> → any insert carries this projectId
|
||||||
|
# expect_dedup: true → task inserts == 0 OR task updates >= 1
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.eval
|
||||||
|
async def test_eval_runner(runner_case, pytestconfig):
|
||||||
|
"""Parametrized eval test — one invocation per YAML case."""
|
||||||
|
case: dict = runner_case
|
||||||
|
data_dir = _fixtures_dir(pytestconfig) / "data"
|
||||||
|
file_content = _read_case_file(case, data_dir)
|
||||||
|
projects = _resolve_projects(case.get("projects", []))
|
||||||
|
|
||||||
|
config = _make_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
executor, calls = _make_executor(
|
||||||
|
file_path=case["file_path"],
|
||||||
|
file_content=file_content,
|
||||||
|
projects=projects,
|
||||||
|
existing_tasks=case.get("existing_tasks"),
|
||||||
|
existing_notes=case.get("existing_notes"),
|
||||||
|
existing_timelines=case.get("existing_timelines"),
|
||||||
|
)
|
||||||
|
|
||||||
|
lf = get_langfuse()
|
||||||
|
obs_ctx = lf.start_as_current_observation(
|
||||||
|
name=f"eval-runner-{case['id']}-{case.get('score_name', 'unknown').replace('.', '-')}",
|
||||||
|
metadata={"step": "2", "case_id": case["id"]},
|
||||||
|
) if lf else nullcontext()
|
||||||
|
|
||||||
|
with obs_ctx as obs:
|
||||||
|
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
||||||
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_fin.call_args
|
||||||
|
score, comment = _evaluate_case(case, calls, kwargs)
|
||||||
|
|
||||||
|
if obs is not None:
|
||||||
|
obs.score(
|
||||||
|
name=case.get("score_name", f"runner.case_{case['id']}"),
|
||||||
|
value=score,
|
||||||
|
comment=comment,
|
||||||
|
)
|
||||||
|
|
||||||
|
if lf:
|
||||||
|
lf.flush()
|
||||||
|
|
||||||
|
assert score == 1.0, f"[{case['id']}] {case.get('description', '')} — {comment}"
|
||||||
|
|
||||||
|
|
||||||
|
def _evaluate_case(case: dict, calls: list[dict], finalize_kwargs: dict) -> tuple[float, str]:
|
||||||
|
"""Return (score, comment) for a YAML case given the captured executor calls."""
|
||||||
|
inserts = [c for c in calls if c["action"] == "insert"]
|
||||||
|
|
||||||
|
if case.get("expect_no_insert"):
|
||||||
|
score = 1.0 if len(inserts) == 0 else 0.0
|
||||||
|
return score, f"inserts={len(inserts)} (expected 0)"
|
||||||
|
|
||||||
|
if "expect_insert" in case:
|
||||||
|
tables = case["expect_insert"]
|
||||||
|
if isinstance(tables, str):
|
||||||
|
tables = [tables]
|
||||||
|
missing = [t for t in tables if not any(c["table"] == t for c in inserts)]
|
||||||
|
score = 1.0 if not missing else 0.0
|
||||||
|
counts = {t: sum(1 for c in inserts if c["table"] == t) for t in tables}
|
||||||
|
return score, f"inserts={counts}" + (f" missing={missing}" if missing else "")
|
||||||
|
|
||||||
|
if "expect_project_id" in case:
|
||||||
|
expected_pid = case["expect_project_id"]
|
||||||
|
correct = any(c.get("data", {}).get("projectId") == expected_pid for c in inserts)
|
||||||
|
score = 1.0 if correct else 0.0
|
||||||
|
all_pids = [c.get("data", {}).get("projectId") for c in inserts]
|
||||||
|
return score, f"projectIds={all_pids} (expected {expected_pid!r})"
|
||||||
|
|
||||||
|
if case.get("expect_dedup"):
|
||||||
|
task_creates = [c for c in inserts if c["table"] == "tasks"]
|
||||||
|
task_updates = [c for c in calls if c["action"] == "update" and c["table"] == "tasks"]
|
||||||
|
score = 1.0 if len(task_creates) == 0 or len(task_updates) >= 1 else 0.0
|
||||||
|
return score, f"task_creates={len(task_creates)} task_updates={len(task_updates)}"
|
||||||
|
|
||||||
|
return 0.0, "no assertion defined in case"
|
||||||
@@ -1,243 +0,0 @@
|
|||||||
"""Tests for the Chatbot Journey endpoints.
|
|
||||||
|
|
||||||
Covers:
|
|
||||||
1. Start journey for local agent → session_id + first question, done=False
|
|
||||||
2. Start journey for cloud agent → contextual email-focused question
|
|
||||||
3. Start journey with existing agent_id → session seeded, first question returned
|
|
||||||
4. Start journey with non-existent agent_id → still succeeds (graceful fallback)
|
|
||||||
5. Message: continue conversation → done=False, follow-up question returned
|
|
||||||
6. Message: LLM wraps up → done=True + prompt_template extracted correctly
|
|
||||||
7. Message with max-turns nudge → no crash, returns response
|
|
||||||
8. Invalid session_id → 404
|
|
||||||
9. Expired session → 404
|
|
||||||
10. Session ownership: user B cannot access user A's session
|
|
||||||
11. No JWT on /start → 401
|
|
||||||
12. No JWT on /message → 401
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.routes.agent_setup import (
|
|
||||||
_SESSION_TTL_SECONDS,
|
|
||||||
_TEMPLATE_END,
|
|
||||||
_TEMPLATE_START,
|
|
||||||
_extract_template,
|
|
||||||
_sessions,
|
|
||||||
)
|
|
||||||
from app.models import LocalAgentConfig
|
|
||||||
from tests.conftest import TEST_USER_IDS, auth_header
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _start(client: TestClient, agent_type: str = "local", agent_id: str | None = None, tier: str = "power") -> dict:
|
|
||||||
body: dict = {"agent_type": agent_type}
|
|
||||||
if agent_id:
|
|
||||||
body["agent_id"] = agent_id
|
|
||||||
resp = client.post("/api/v1/agents/journey/start", json=body, headers=auth_header(tier))
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
def _message(client: TestClient, session_id: str, message: str, tier: str = "power") -> dict:
|
|
||||||
return client.post(
|
|
||||||
"/api/v1/agents/journey/message",
|
|
||||||
json={"session_id": session_id, "message": message},
|
|
||||||
headers=auth_header(tier),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: _extract_template ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_template_present():
|
|
||||||
text = f"Some preamble.\n{_TEMPLATE_START}\nExtract tasks from emails.\n{_TEMPLATE_END}\nTrailing text."
|
|
||||||
result = _extract_template(text)
|
|
||||||
assert result == "Extract tasks from emails."
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_template_absent():
|
|
||||||
assert _extract_template("No markers here.") is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_template_empty_content():
|
|
||||||
text = f"{_TEMPLATE_START}\n{_TEMPLATE_END}"
|
|
||||||
assert _extract_template(text) is None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Start journey ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_start_journey_local(client: TestClient):
|
|
||||||
resp = _start(client, agent_type="local")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert "session_id" in body
|
|
||||||
assert body["done"] is False
|
|
||||||
assert body["prompt_template"] is None
|
|
||||||
assert len(body["message"]) > 0
|
|
||||||
# Local question should be about files/directories
|
|
||||||
assert any(w in body["message"].lower() for w in ("file", "director", "document", "monitor"))
|
|
||||||
|
|
||||||
|
|
||||||
def test_start_journey_cloud(client: TestClient):
|
|
||||||
resp = _start(client, agent_type="cloud")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert body["done"] is False
|
|
||||||
# Cloud question should mention emails or messages
|
|
||||||
assert any(w in body["message"].lower() for w in ("email", "message", "communication"))
|
|
||||||
|
|
||||||
|
|
||||||
def test_start_journey_with_agent_id(client: TestClient, db_session: AsyncSession):
|
|
||||||
"""When agent_id is provided, session should be created even if agent doesn't exist."""
|
|
||||||
fake_agent_id = str(uuid.uuid4())
|
|
||||||
resp = _start(client, agent_type="local", agent_id=fake_agent_id)
|
|
||||||
# Should succeed gracefully even if the agent_id doesn't exist
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert body["done"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_start_journey_with_existing_agent(client: TestClient, db_session: AsyncSession):
|
|
||||||
"""When a real local agent is provided, session is seeded with its prompt_template."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
user_id = TEST_USER_IDS["power"]
|
|
||||||
agent = LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
name="Test Agent",
|
|
||||||
device_id="device-1",
|
|
||||||
directory_paths=["/home/user/emails"],
|
|
||||||
data_types=["tasks"],
|
|
||||||
prompt_template="Extract tasks from .eml files.",
|
|
||||||
file_extensions=[".eml"],
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _seed():
|
|
||||||
db_session.add(agent)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(_seed())
|
|
||||||
|
|
||||||
resp = _start(client, agent_type="local", agent_id=agent.id)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert body["done"] is False
|
|
||||||
# The session should be stored
|
|
||||||
assert body["session_id"] in _sessions
|
|
||||||
|
|
||||||
|
|
||||||
def test_start_journey_requires_auth(client: TestClient):
|
|
||||||
resp = client.post("/api/v1/agents/journey/start", json={"agent_type": "local"})
|
|
||||||
assert resp.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
# ── Message ───────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_continues_conversation(client: TestClient):
|
|
||||||
"""A mid-journey reply (no template markers) returns done=False."""
|
|
||||||
follow_up = "That looks good. Can you tell me more about priority rules?"
|
|
||||||
|
|
||||||
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
|
|
||||||
start_resp = _start(client, agent_type="local")
|
|
||||||
assert start_resp.status_code == 200
|
|
||||||
session_id = start_resp.json()["session_id"]
|
|
||||||
|
|
||||||
msg_resp = _message(client, session_id, "I have .eml and .txt files")
|
|
||||||
assert msg_resp.status_code == 200
|
|
||||||
body = msg_resp.json()
|
|
||||||
assert body["done"] is False
|
|
||||||
assert body["prompt_template"] is None
|
|
||||||
assert body["message"] == follow_up
|
|
||||||
assert body["session_id"] == session_id
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_produces_template(client: TestClient):
|
|
||||||
"""When the LLM includes PROMPT_TEMPLATE markers, done=True and prompt_template is set."""
|
|
||||||
final_template = "Extract tasks from email. Subject → title. 'urgent' → high priority."
|
|
||||||
llm_response = (
|
|
||||||
"Great, I have all the information I need.\n"
|
|
||||||
f"{_TEMPLATE_START}\n{final_template}\n{_TEMPLATE_END}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=llm_response)):
|
|
||||||
start_resp = _start(client, agent_type="cloud")
|
|
||||||
assert start_resp.status_code == 200
|
|
||||||
session_id = start_resp.json()["session_id"]
|
|
||||||
|
|
||||||
msg_resp = _message(client, session_id, "Only invoices from clients")
|
|
||||||
assert msg_resp.status_code == 200
|
|
||||||
body = msg_resp.json()
|
|
||||||
assert body["done"] is True
|
|
||||||
assert body["prompt_template"] == final_template
|
|
||||||
# Session should be cleaned up
|
|
||||||
assert session_id not in _sessions
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_invalid_session(client: TestClient):
|
|
||||||
resp = _message(client, "nonexistent-session-id", "hello")
|
|
||||||
assert resp.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_wrong_owner(client: TestClient):
|
|
||||||
"""User B cannot access user A's session."""
|
|
||||||
start_resp = _start(client, agent_type="local", tier="power")
|
|
||||||
session_id = start_resp.json()["session_id"]
|
|
||||||
|
|
||||||
# user with "pro" tier (different user_id) tries to send a message
|
|
||||||
resp = client.post(
|
|
||||||
"/api/v1/agents/journey/message",
|
|
||||||
json={"session_id": session_id, "message": "hello"},
|
|
||||||
headers=auth_header("pro"), # different user
|
|
||||||
)
|
|
||||||
assert resp.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_expired_session(client: TestClient):
|
|
||||||
"""Expired sessions return 404."""
|
|
||||||
start_resp = _start(client, agent_type="local")
|
|
||||||
session_id = start_resp.json()["session_id"]
|
|
||||||
|
|
||||||
# Manually expire the session
|
|
||||||
_sessions[session_id].created_at = time.monotonic() - _SESSION_TTL_SECONDS - 1
|
|
||||||
|
|
||||||
resp = _message(client, session_id, "hello")
|
|
||||||
assert resp.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_requires_auth(client: TestClient):
|
|
||||||
resp = client.post(
|
|
||||||
"/api/v1/agents/journey/message",
|
|
||||||
json={"session_id": "any", "message": "hello"},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_max_turns_nudge(client: TestClient):
|
|
||||||
"""After _MAX_TURNS user messages, a system nudge is appended but no crash occurs."""
|
|
||||||
from app.api.routes.agent_setup import _MAX_TURNS
|
|
||||||
|
|
||||||
follow_up = "Tell me more about priority rules."
|
|
||||||
|
|
||||||
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
|
|
||||||
start_resp = _start(client, agent_type="local")
|
|
||||||
session_id = start_resp.json()["session_id"]
|
|
||||||
|
|
||||||
for i in range(_MAX_TURNS):
|
|
||||||
resp = _message(client, session_id, f"Answer {i + 1}")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
# While no template produced, session must still exist
|
|
||||||
if resp.json()["done"]:
|
|
||||||
break # LLM decided to wrap up early — also fine
|
|
||||||
@@ -1,416 +0,0 @@
|
|||||||
"""Tests for ChatAgent streaming and tool result capture (Step 2)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
|
||||||
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
|
|
||||||
|
|
||||||
# ── Minimal concrete agent for testing ───────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _EchoAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "_echo"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Echo agent for tests"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return query
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage:
|
|
||||||
msg = AIMessage(content=content)
|
|
||||||
if tool_calls:
|
|
||||||
msg.tool_calls = tool_calls
|
|
||||||
else:
|
|
||||||
msg.tool_calls = []
|
|
||||||
return msg
|
|
||||||
|
|
||||||
|
|
||||||
def _make_tool(name: str, return_value: Any) -> MagicMock:
|
|
||||||
t = MagicMock()
|
|
||||||
t.name = name
|
|
||||||
t.ainvoke = AsyncMock(return_value=return_value)
|
|
||||||
return t
|
|
||||||
|
|
||||||
|
|
||||||
def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]:
|
|
||||||
chunks = []
|
|
||||||
for tok in tokens:
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = tok
|
|
||||||
chunks.append(c)
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
|
|
||||||
async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]:
|
|
||||||
tokens: list[str] = []
|
|
||||||
async for tok in agent._tool_loop_stream(llm, messages, tools):
|
|
||||||
tokens.append(tok)
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
|
|
||||||
# ── tool_results initialised ─────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_tool_results_init():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
assert agent.tool_results == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop: no tool calls ────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_no_tools():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Hello!"))
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
|
|
||||||
assert result == "Hello!"
|
|
||||||
assert agent.tool_results == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop: with one tool call + result capture ──────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_captures_tool_results():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
# Mock execute_on_client to return structured data via the tool
|
|
||||||
raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]}
|
|
||||||
|
|
||||||
async def fake_executor(payload: dict) -> dict:
|
|
||||||
return raw_result
|
|
||||||
|
|
||||||
# AIMessage with a tool call, then a final answer
|
|
||||||
tool_call_msg = _make_ai_message(
|
|
||||||
tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}]
|
|
||||||
)
|
|
||||||
final_msg = _make_ai_message("Here are your tasks.")
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm_with_tools = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
|
||||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
|
||||||
|
|
||||||
mock_tool = _make_tool("list_tasks", "- Fix bug (todo)")
|
|
||||||
|
|
||||||
from app.core.ws_context import set_client_executor, clear_client_executor
|
|
||||||
set_client_executor(fake_executor)
|
|
||||||
try:
|
|
||||||
# Patch the tool to actually call execute_on_client
|
|
||||||
async def tool_side_effect(args: dict) -> str:
|
|
||||||
from app.core.ws_context import execute_on_client
|
|
||||||
res = await execute_on_client(action="select", table="tasks")
|
|
||||||
rows = res.get("rows", [])
|
|
||||||
return "\n".join(r["title"] for r in rows)
|
|
||||||
|
|
||||||
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
|
||||||
|
|
||||||
result = await agent._tool_loop(
|
|
||||||
llm, [HumanMessage(content="list my tasks")], [mock_tool]
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
assert result == "Here are your tasks."
|
|
||||||
assert len(agent.tool_results) == 1
|
|
||||||
assert agent.tool_results[0] == raw_result
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop: tool_results reset on each call ──────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_resets_tool_results():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
agent.tool_results = [{"stale": True}] # pre-populated from a previous call
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done."))
|
|
||||||
|
|
||||||
await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
|
|
||||||
assert agent.tool_results == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop: unknown tool name ────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_unknown_tool():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
# No known tools — model still calls a non-existent one; loop handles gracefully
|
|
||||||
tool_call_msg = _make_ai_message(
|
|
||||||
tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}]
|
|
||||||
)
|
|
||||||
final_msg = _make_ai_message("Handled.")
|
|
||||||
|
|
||||||
mock_tool = _make_tool("known", "ok") # a different tool, not "nonexistent"
|
|
||||||
llm = MagicMock()
|
|
||||||
llm_with_tools = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool])
|
|
||||||
assert result == "Handled."
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop: max_iter exhaustion ──────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_max_iter():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
always_tool = _make_ai_message(
|
|
||||||
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
|
|
||||||
)
|
|
||||||
fallback = _make_ai_message("Fallback.")
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm_with_tools = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
# Returns tool_call_msg on every iteration
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=fallback)
|
|
||||||
|
|
||||||
mock_tool = _make_tool("t", "ok")
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2)
|
|
||||||
assert result == "Fallback."
|
|
||||||
assert llm_with_tools.ainvoke.call_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop_stream: no tool calls — yields tokens ─────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_stream_no_tools_yields_tokens():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
# No tools → llm used directly; ainvoke returns no tool calls → stream is used
|
|
||||||
no_tool_msg = _make_ai_message("irrelevant")
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
|
||||||
|
|
||||||
async def fake_astream(msgs):
|
|
||||||
for tok in ["Hello", " ", "world"]:
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = tok
|
|
||||||
yield c
|
|
||||||
|
|
||||||
llm.astream = fake_astream
|
|
||||||
|
|
||||||
tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], [])
|
|
||||||
assert tokens == ["Hello", " ", "world"]
|
|
||||||
assert agent.tool_results == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop_stream: one tool call then streaming final ─────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_stream_with_tool_call():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}}
|
|
||||||
|
|
||||||
async def fake_executor(payload: dict) -> dict:
|
|
||||||
return raw_result
|
|
||||||
|
|
||||||
tool_call_msg = _make_ai_message(
|
|
||||||
tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}]
|
|
||||||
)
|
|
||||||
# After tools run, ainvoke returns no more tool calls
|
|
||||||
no_more_tools_msg = _make_ai_message("Task found.")
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm_with_tools = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
|
|
||||||
|
|
||||||
async def fake_astream(msgs):
|
|
||||||
for tok in ["Task", " ", "found."]:
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = tok
|
|
||||||
yield c
|
|
||||||
|
|
||||||
llm.astream = fake_astream
|
|
||||||
|
|
||||||
async def tool_side_effect(args: dict) -> str:
|
|
||||||
from app.core.ws_context import execute_on_client
|
|
||||||
res = await execute_on_client(action="select", table="tasks", filters={"id": args.get("id")})
|
|
||||||
return res.get("row", {}).get("title", "")
|
|
||||||
|
|
||||||
mock_tool = _make_tool("get_task", "Deploy")
|
|
||||||
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
|
||||||
|
|
||||||
from app.core.ws_context import set_client_executor, clear_client_executor
|
|
||||||
set_client_executor(fake_executor)
|
|
||||||
try:
|
|
||||||
tokens = await _collect_stream(
|
|
||||||
agent, llm, [HumanMessage(content="get task t-2")], [mock_tool]
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
assert tokens == ["Task", " ", "found."]
|
|
||||||
assert len(agent.tool_results) == 1
|
|
||||||
assert agent.tool_results[0] == raw_result
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop_stream: tool_results reset on each call ───────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_stream_resets_tool_results():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
agent.tool_results = [{"old": True}]
|
|
||||||
|
|
||||||
no_tool_msg = _make_ai_message("")
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
|
||||||
|
|
||||||
async def fake_astream(msgs):
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = "ok"
|
|
||||||
yield c
|
|
||||||
|
|
||||||
llm.astream = fake_astream
|
|
||||||
|
|
||||||
await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
|
|
||||||
assert agent.tool_results == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop_stream: empty chunk content is skipped ────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_stream_skips_empty_chunks():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
no_tool_msg = _make_ai_message("")
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
|
||||||
|
|
||||||
async def fake_astream(msgs):
|
|
||||||
for tok in ["", "hello", "", " world", ""]:
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = tok
|
|
||||||
yield c
|
|
||||||
|
|
||||||
llm.astream = fake_astream
|
|
||||||
|
|
||||||
tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
|
|
||||||
assert tokens == ["hello", " world"]
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop_stream: max_iter exhaustion falls back to stream ───────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_stream_max_iter():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
always_tool = _make_ai_message(
|
|
||||||
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm_with_tools = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
|
|
||||||
|
|
||||||
async def fake_astream(msgs):
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = "fallback"
|
|
||||||
yield c
|
|
||||||
|
|
||||||
llm.astream = fake_astream
|
|
||||||
mock_tool = _make_tool("t", "ok")
|
|
||||||
|
|
||||||
tokens = await _collect_stream(
|
|
||||||
agent, llm, [HumanMessage(content="x")], [mock_tool],
|
|
||||||
)
|
|
||||||
assert tokens == ["fallback"]
|
|
||||||
assert llm_with_tools.ainvoke.call_count == 5 # exhausted default max_iter
|
|
||||||
|
|
||||||
|
|
||||||
# ── _tool_loop_stream: multiple tool results captured ────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_loop_stream_multiple_tool_results():
|
|
||||||
agent = _EchoAgent()
|
|
||||||
|
|
||||||
call_results = [
|
|
||||||
{"rows": [{"id": "t-1"}]},
|
|
||||||
{"rows": [{"id": "t-2"}]},
|
|
||||||
]
|
|
||||||
call_iter = iter(call_results)
|
|
||||||
|
|
||||||
async def fake_executor(payload: dict) -> dict:
|
|
||||||
return next(call_iter)
|
|
||||||
|
|
||||||
# Two tool calls in one iteration
|
|
||||||
tool_call_msg = _make_ai_message(
|
|
||||||
tool_calls=[
|
|
||||||
{"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"},
|
|
||||||
{"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
no_more_tools_msg = _make_ai_message("Done.")
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm_with_tools = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
|
|
||||||
|
|
||||||
async def fake_astream(msgs):
|
|
||||||
c = MagicMock()
|
|
||||||
c.content = "Done."
|
|
||||||
yield c
|
|
||||||
|
|
||||||
llm.astream = fake_astream
|
|
||||||
|
|
||||||
async def tool_side_effect(args: dict) -> str:
|
|
||||||
from app.core.ws_context import execute_on_client
|
|
||||||
res = await execute_on_client(action="select", table="tasks")
|
|
||||||
return str(res)
|
|
||||||
|
|
||||||
tool_a = _make_tool("tool_a", "")
|
|
||||||
tool_a.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
|
||||||
tool_b = _make_tool("tool_b", "")
|
|
||||||
tool_b.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
|
||||||
|
|
||||||
from app.core.ws_context import set_client_executor, clear_client_executor
|
|
||||||
set_client_executor(fake_executor)
|
|
||||||
try:
|
|
||||||
tokens = await _collect_stream(
|
|
||||||
agent, llm, [HumanMessage(content="x")], [tool_a, tool_b]
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
assert tokens == ["Done."]
|
|
||||||
assert len(agent.tool_results) == 2
|
|
||||||
assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]}
|
|
||||||
assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]}
|
|
||||||
@@ -1,761 +0,0 @@
|
|||||||
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
|
||||||
from app.agents.timeline_agent import TimelineAgent
|
|
||||||
from app.agents.note_agent import NoteAgent
|
|
||||||
from app.agents.project_agent import ProjectAgent
|
|
||||||
from app.agents.task_agent import TaskAgent
|
|
||||||
from app.core.agent_registry import registry
|
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
|
||||||
|
|
||||||
|
|
||||||
# ── WS executor mock ──────────────────────────────────────────────────
|
|
||||||
#
|
|
||||||
# Tools call execute_on_client() which reads a ContextVar set by the WS
|
|
||||||
# handler. In unit tests there is no WS session, so we install a fake
|
|
||||||
# executor that returns plausible data for each action type.
|
|
||||||
|
|
||||||
_FAKE_ROW: dict[str, Any] = {
|
|
||||||
"id": "fake-id",
|
|
||||||
"title": "Fake Title",
|
|
||||||
"name": "Fake Name",
|
|
||||||
"status": "todo",
|
|
||||||
"priority": "medium",
|
|
||||||
"content": "Fake content",
|
|
||||||
"date": 1700000000000,
|
|
||||||
"taskId": "fake-task-id",
|
|
||||||
"author": "Alice",
|
|
||||||
"projectId": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def _fake_executor(payload: dict) -> dict:
|
|
||||||
action = payload.get("action", "")
|
|
||||||
if action == "select":
|
|
||||||
return {"rows": []}
|
|
||||||
if action == "insert":
|
|
||||||
data = payload.get("data", {})
|
|
||||||
return {"row": {**_FAKE_ROW, **data}}
|
|
||||||
if action == "update":
|
|
||||||
data = payload.get("data", {})
|
|
||||||
row = {**_FAKE_ROW, "id": data.get("id", "fake-id"), **data.get("updates", {})}
|
|
||||||
return {"row": row}
|
|
||||||
if action == "delete":
|
|
||||||
return {"deleted": True}
|
|
||||||
if action == "get":
|
|
||||||
data = payload.get("data", {})
|
|
||||||
return {"row": {**_FAKE_ROW, "id": data.get("id", "fake-id")}}
|
|
||||||
if action == "vector_upsert":
|
|
||||||
return {"ok": True}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def ws_executor():
|
|
||||||
"""Install a fake WS executor for every test so tools can run without a real WS."""
|
|
||||||
set_client_executor(_fake_executor)
|
|
||||||
yield
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_llm(response_text: str) -> MagicMock:
|
|
||||||
"""Return a mock LLM that responds with *response_text* (no tool calls)."""
|
|
||||||
msg = MagicMock()
|
|
||||||
msg.content = response_text
|
|
||||||
msg.tool_calls = []
|
|
||||||
llm = MagicMock()
|
|
||||||
bound = MagicMock()
|
|
||||||
bound.ainvoke = AsyncMock(return_value=msg)
|
|
||||||
llm.bind_tools = MagicMock(return_value=bound)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=msg)
|
|
||||||
return llm
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_llm_with_tool_call(
|
|
||||||
tool_name: str, tool_args: dict[str, Any], final_text: str
|
|
||||||
) -> MagicMock:
|
|
||||||
"""Mock LLM that fires one tool call then returns *final_text*."""
|
|
||||||
tool_msg = MagicMock()
|
|
||||||
tool_msg.content = ""
|
|
||||||
tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
|
|
||||||
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = final_text
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
bound = MagicMock()
|
|
||||||
bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=bound)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
|
||||||
return llm
|
|
||||||
|
|
||||||
|
|
||||||
# ── Registration ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgentRegistration:
|
|
||||||
def test_all_agents_registered(self) -> None:
|
|
||||||
names = {a["name"] for a in registry.list_agents()}
|
|
||||||
assert {
|
|
||||||
"task_agent", "timeline_agent", "project_agent", "note_agent"
|
|
||||||
}.issubset(names)
|
|
||||||
|
|
||||||
def test_registry_returns_correct_types(self) -> None:
|
|
||||||
assert isinstance(registry.get("task_agent"), TaskAgent)
|
|
||||||
assert isinstance(registry.get("timeline_agent"), TimelineAgent)
|
|
||||||
assert isinstance(registry.get("project_agent"), ProjectAgent)
|
|
||||||
assert isinstance(registry.get("note_agent"), NoteAgent)
|
|
||||||
|
|
||||||
def test_descriptions_present(self) -> None:
|
|
||||||
for agent_info in registry.list_agents():
|
|
||||||
assert agent_info["description"], f"Empty description: {agent_info['name']}"
|
|
||||||
|
|
||||||
|
|
||||||
# ── TaskAgent ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestTaskAgent:
|
|
||||||
def test_name(self) -> None:
|
|
||||||
assert TaskAgent().get_name() == "task_agent"
|
|
||||||
|
|
||||||
def test_description(self) -> None:
|
|
||||||
assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
|
|
||||||
def test_get_tools_count(self) -> None:
|
|
||||||
assert len(TaskAgent().get_tools()) == 8
|
|
||||||
|
|
||||||
def test_tool_names(self) -> None:
|
|
||||||
names = {t.name for t in TaskAgent().get_tools()}
|
|
||||||
assert names == {
|
|
||||||
"list_tasks",
|
|
||||||
"create_task",
|
|
||||||
"update_task",
|
|
||||||
"delete_task",
|
|
||||||
"list_tasks_due_today",
|
|
||||||
"list_task_comments",
|
|
||||||
"add_task_comment",
|
|
||||||
"delete_task_comment",
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_returns_string(self) -> None:
|
|
||||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Task created.")
|
|
||||||
result = await TaskAgent().handle("create a task", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_no_tool_calls(self) -> None:
|
|
||||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Here are your tasks.")
|
|
||||||
result = await TaskAgent().handle("list my tasks", {})
|
|
||||||
assert result == "Here are your tasks."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_with_create_task_tool_call(self) -> None:
|
|
||||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
|
||||||
"create_task",
|
|
||||||
{"title": "Buy groceries", "priority": "low"},
|
|
||||||
"Task 'Buy groceries' created.",
|
|
||||||
)
|
|
||||||
result = await TaskAgent().handle("add a grocery task", {})
|
|
||||||
assert result == "Task 'Buy groceries' created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_empty_context(self) -> None:
|
|
||||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Done.")
|
|
||||||
result = await TaskAgent().handle("help", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_rich_context(self) -> None:
|
|
||||||
context = {
|
|
||||||
"user_profile": {"id": "u1", "tier": "pro"},
|
|
||||||
"recent_tasks": [{"id": "t1", "title": "Old task"}],
|
|
||||||
}
|
|
||||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Tasks listed.")
|
|
||||||
result = await TaskAgent().handle("show tasks", context)
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestTaskAgentTools:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tasks_defaults(self) -> None:
|
|
||||||
from app.agents.task_agent import list_tasks
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_tasks.ainvoke({})
|
|
||||||
m.assert_called_once_with(
|
|
||||||
action="select", table="tasks",
|
|
||||||
filters={"projectId": None, "status": None, "search": None, "orderBy": None},
|
|
||||||
)
|
|
||||||
assert result == "No tasks found matching the given filters."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tasks_with_status_filter(self) -> None:
|
|
||||||
from app.agents.task_agent import list_tasks
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
await list_tasks.ainvoke({"status": "done"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["filters"]["status"] == "done"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_task_defaults(self) -> None:
|
|
||||||
from app.agents.task_agent import create_task
|
|
||||||
fake_row = {"id": "t1", "title": "Test task", "status": "todo", "priority": "medium"}
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await create_task.ainvoke({"title": "Test task"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "insert"
|
|
||||||
assert call_kwargs["table"] == "tasks"
|
|
||||||
assert call_kwargs["data"]["title"] == "Test task"
|
|
||||||
assert call_kwargs["data"]["status"] == "todo"
|
|
||||||
assert call_kwargs["data"]["priority"] == "medium"
|
|
||||||
assert "Test task" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_task_with_all_fields(self) -> None:
|
|
||||||
from app.agents.task_agent import create_task
|
|
||||||
fake_row = {"id": "t1", "title": "Deploy", "status": "in_progress", "priority": "high"}
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await create_task.ainvoke({
|
|
||||||
"title": "Deploy", "priority": "high", "status": "in_progress",
|
|
||||||
"project_id": "p1", "is_ai_suggested": 1,
|
|
||||||
})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["data"]["priority"] == "high"
|
|
||||||
assert call_kwargs["data"]["status"] == "in_progress"
|
|
||||||
assert call_kwargs["data"]["projectId"] == "p1"
|
|
||||||
assert call_kwargs["data"]["isAiSuggested"] == 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_task_with_status(self) -> None:
|
|
||||||
from app.agents.task_agent import update_task
|
|
||||||
fake_row = {"id": "t1", "title": "Buy groceries", "status": "done"}
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "update"
|
|
||||||
assert call_kwargs["data"]["id"] == "t1"
|
|
||||||
assert call_kwargs["data"]["updates"]["status"] == "done"
|
|
||||||
assert "t1" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_task_empty_updates(self) -> None:
|
|
||||||
from app.agents.task_agent import update_task
|
|
||||||
fake_row = {"id": "t1", "title": "Task", "status": "todo"}
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await update_task.ainvoke({"task_id": "t1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["data"]["updates"] == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_task(self) -> None:
|
|
||||||
from app.agents.task_agent import delete_task
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"deleted": True}
|
|
||||||
result = await delete_task.ainvoke({"task_id": "t1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "delete"
|
|
||||||
assert call_kwargs["table"] == "tasks"
|
|
||||||
assert call_kwargs["data"]["id"] == "t1"
|
|
||||||
assert "t1" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tasks_due_today(self) -> None:
|
|
||||||
from app.agents.task_agent import list_tasks_due_today
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_tasks_due_today.ainvoke({})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "select"
|
|
||||||
assert call_kwargs["table"] == "tasks"
|
|
||||||
assert "dueDateFrom" in call_kwargs["filters"]
|
|
||||||
assert result == "No tasks are due today."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_task_comments(self) -> None:
|
|
||||||
from app.agents.task_agent import list_task_comments
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_task_comments.ainvoke({"task_id": "t1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "select"
|
|
||||||
assert call_kwargs["table"] == "taskComments"
|
|
||||||
assert call_kwargs["filters"]["taskId"] == "t1"
|
|
||||||
assert "t1" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_add_task_comment(self) -> None:
|
|
||||||
from app.agents.task_agent import add_task_comment
|
|
||||||
fake_row = {"id": "c1", "taskId": "t1", "author": "Alice", "content": "Looks good!"}
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await add_task_comment.ainvoke({
|
|
||||||
"task_id": "t1", "author": "Alice", "content": "Looks good!",
|
|
||||||
})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "insert"
|
|
||||||
assert call_kwargs["table"] == "taskComments"
|
|
||||||
assert call_kwargs["data"]["taskId"] == "t1"
|
|
||||||
assert call_kwargs["data"]["author"] == "Alice"
|
|
||||||
assert call_kwargs["data"]["content"] == "Looks good!"
|
|
||||||
assert "Alice" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_task_comment(self) -> None:
|
|
||||||
from app.agents.task_agent import delete_task_comment
|
|
||||||
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"deleted": True}
|
|
||||||
result = await delete_task_comment.ainvoke({"comment_id": "c1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "delete"
|
|
||||||
assert call_kwargs["table"] == "taskComments"
|
|
||||||
assert call_kwargs["data"]["id"] == "c1"
|
|
||||||
assert "c1" in result
|
|
||||||
|
|
||||||
|
|
||||||
# ── TimelineAgent ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestTimelineAgent:
|
|
||||||
def test_name(self) -> None:
|
|
||||||
assert TimelineAgent().get_name() == "timeline_agent"
|
|
||||||
|
|
||||||
def test_description(self) -> None:
|
|
||||||
assert TimelineAgent().get_description() == "Manages project timelines (milestones): list, create, update, delete"
|
|
||||||
|
|
||||||
def test_get_tools_count(self) -> None:
|
|
||||||
assert len(TimelineAgent().get_tools()) == 4
|
|
||||||
|
|
||||||
def test_tool_names(self) -> None:
|
|
||||||
names = {t.name for t in TimelineAgent().get_tools()}
|
|
||||||
assert names == {"list_timelines", "create_timeline", "update_timeline", "delete_timeline"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_no_tool_calls(self) -> None:
|
|
||||||
with patch("app.agents.timeline_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("No timelines found.")
|
|
||||||
result = await TimelineAgent().handle("list timelines", {})
|
|
||||||
assert result == "No timelines found."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_with_create_tool_call(self) -> None:
|
|
||||||
with patch("app.agents.timeline_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
|
||||||
"create_timeline",
|
|
||||||
{"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
|
|
||||||
"Timeline 'MVP Launch' created.",
|
|
||||||
)
|
|
||||||
result = await TimelineAgent().handle("add MVP timeline", {})
|
|
||||||
assert result == "Timeline 'MVP Launch' created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_empty_context(self) -> None:
|
|
||||||
with patch("app.agents.timeline_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Done.")
|
|
||||||
result = await TimelineAgent().handle("show milestones", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestTimelineAgentTools:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_timelines_no_project(self) -> None:
|
|
||||||
from app.agents.timeline_agent import list_timelines
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_timelines.ainvoke({})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "select"
|
|
||||||
assert call_kwargs["table"] == "timelines"
|
|
||||||
assert call_kwargs["filters"]["projectId"] is None
|
|
||||||
assert result == "No timelines found."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_timelines_with_project(self) -> None:
|
|
||||||
from app.agents.timeline_agent import list_timelines
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
await list_timelines.ainvoke({"project_id": "p1"})
|
|
||||||
assert m.call_args.kwargs["filters"]["projectId"] == "p1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_timeline(self) -> None:
|
|
||||||
from app.agents.timeline_agent import create_timeline
|
|
||||||
fake_row = {"id": "cp1", "title": "Beta release", "date": 1700000000000}
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await create_timeline.ainvoke({
|
|
||||||
"project_id": "p1", "title": "Beta release", "date": 1700000000000,
|
|
||||||
})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "insert"
|
|
||||||
assert call_kwargs["table"] == "timelines"
|
|
||||||
assert call_kwargs["data"]["projectId"] == "p1"
|
|
||||||
assert call_kwargs["data"]["title"] == "Beta release"
|
|
||||||
assert call_kwargs["data"]["date"] == 1700000000000
|
|
||||||
assert "Beta release" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_timeline_ai_suggested(self) -> None:
|
|
||||||
from app.agents.timeline_agent import create_timeline
|
|
||||||
fake_row = {"id": "cp1", "title": "Review", "date": 1700000000000}
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await create_timeline.ainvoke({
|
|
||||||
"project_id": "p1", "title": "Review", "date": 1700000000000, "is_ai_suggested": 1,
|
|
||||||
})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["data"]["isAiSuggested"] == 1
|
|
||||||
assert call_kwargs["data"]["isApproved"] == 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_timeline_approve(self) -> None:
|
|
||||||
from app.agents.timeline_agent import update_timeline
|
|
||||||
fake_row = {"id": "c1", "title": "MVP", "isApproved": 1}
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await update_timeline.ainvoke({"timeline_id": "c1", "is_approved": 1})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "update"
|
|
||||||
assert call_kwargs["data"]["id"] == "c1"
|
|
||||||
assert call_kwargs["data"]["updates"]["isApproved"] == 1
|
|
||||||
assert "c1" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_timeline_empty_updates(self) -> None:
|
|
||||||
from app.agents.timeline_agent import update_timeline
|
|
||||||
fake_row = {"id": "c1", "title": "MVP"}
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await update_timeline.ainvoke({"timeline_id": "c1"})
|
|
||||||
assert m.call_args.kwargs["data"]["updates"] == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_timeline(self) -> None:
|
|
||||||
from app.agents.timeline_agent import delete_timeline
|
|
||||||
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"deleted": True}
|
|
||||||
result = await delete_timeline.ainvoke({"timeline_id": "c1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "delete"
|
|
||||||
assert call_kwargs["table"] == "timelines"
|
|
||||||
assert call_kwargs["data"]["id"] == "c1"
|
|
||||||
assert "c1" in result
|
|
||||||
|
|
||||||
|
|
||||||
# ── ProjectAgent ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestProjectAgent:
|
|
||||||
def test_name(self) -> None:
|
|
||||||
assert ProjectAgent().get_name() == "project_agent"
|
|
||||||
|
|
||||||
def test_description(self) -> None:
|
|
||||||
assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
|
|
||||||
def test_get_tools_count(self) -> None:
|
|
||||||
assert len(ProjectAgent().get_tools()) == 6
|
|
||||||
|
|
||||||
def test_tool_names(self) -> None:
|
|
||||||
names = {t.name for t in ProjectAgent().get_tools()}
|
|
||||||
assert names == {
|
|
||||||
"list_projects",
|
|
||||||
"list_all_projects",
|
|
||||||
"get_project",
|
|
||||||
"create_project",
|
|
||||||
"update_project",
|
|
||||||
"delete_project",
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_no_tool_calls(self) -> None:
|
|
||||||
with patch("app.agents.project_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Project Alpha is active.")
|
|
||||||
result = await ProjectAgent().handle("show my projects", {})
|
|
||||||
assert result == "Project Alpha is active."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_with_create_project_tool_call(self) -> None:
|
|
||||||
with patch("app.agents.project_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
|
||||||
"create_project",
|
|
||||||
{"name": "Pippo"},
|
|
||||||
"Project 'Pippo' created.",
|
|
||||||
)
|
|
||||||
result = await ProjectAgent().handle("create project Pippo", {})
|
|
||||||
assert result == "Project 'Pippo' created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_empty_context(self) -> None:
|
|
||||||
with patch("app.agents.project_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Done.")
|
|
||||||
result = await ProjectAgent().handle("archive old project", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestProjectAgentTools:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_projects_defaults(self) -> None:
|
|
||||||
from app.agents.project_agent import list_projects
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_projects.ainvoke({})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "select"
|
|
||||||
assert call_kwargs["table"] == "projects"
|
|
||||||
assert call_kwargs["filters"]["includeArchived"] is False
|
|
||||||
assert result == "No projects found."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_projects_include_archived(self) -> None:
|
|
||||||
from app.agents.project_agent import list_projects
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
await list_projects.ainvoke({"include_archived": 1})
|
|
||||||
assert m.call_args.kwargs["filters"]["includeArchived"] is True
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_all_projects(self) -> None:
|
|
||||||
from app.agents.project_agent import list_all_projects
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_all_projects.ainvoke({})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "select"
|
|
||||||
assert call_kwargs["table"] == "projects"
|
|
||||||
assert result == "No projects found."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_project(self) -> None:
|
|
||||||
from app.agents.project_agent import get_project
|
|
||||||
fake_row = {"id": "p1", "name": "Alpha", "status": "active", "clientId": None}
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await get_project.ainvoke({"project_id": "p1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "get"
|
|
||||||
assert call_kwargs["table"] == "projects"
|
|
||||||
assert call_kwargs["data"]["id"] == "p1"
|
|
||||||
assert "Alpha" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_project_name_only(self) -> None:
|
|
||||||
from app.agents.project_agent import create_project
|
|
||||||
fake_row = {"id": "p1", "name": "Alpha"}
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await create_project.ainvoke({"name": "Alpha"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "insert"
|
|
||||||
assert call_kwargs["data"]["name"] == "Alpha"
|
|
||||||
assert call_kwargs["data"]["clientId"] is None
|
|
||||||
assert "Alpha" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_project_with_client(self) -> None:
|
|
||||||
from app.agents.project_agent import create_project
|
|
||||||
fake_row = {"id": "p1", "name": "Beta"}
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
|
|
||||||
assert m.call_args.kwargs["data"]["clientId"] == "cl1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_project_archive(self) -> None:
|
|
||||||
from app.agents.project_agent import update_project
|
|
||||||
fake_row = {"id": "p1", "name": "Alpha", "status": "archived"}
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "update"
|
|
||||||
assert call_kwargs["data"]["id"] == "p1"
|
|
||||||
assert call_kwargs["data"]["updates"]["status"] == "archived"
|
|
||||||
assert "p1" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_project_empty_updates(self) -> None:
|
|
||||||
from app.agents.project_agent import update_project
|
|
||||||
fake_row = {"id": "p1", "name": "Alpha", "status": "active"}
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await update_project.ainvoke({"project_id": "p1"})
|
|
||||||
assert m.call_args.kwargs["data"]["updates"] == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_project(self) -> None:
|
|
||||||
from app.agents.project_agent import delete_project
|
|
||||||
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"deleted": True}
|
|
||||||
result = await delete_project.ainvoke({"project_id": "p1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "delete"
|
|
||||||
assert call_kwargs["data"]["id"] == "p1"
|
|
||||||
assert "p1" in result
|
|
||||||
|
|
||||||
|
|
||||||
# ── NoteAgent ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestNoteAgent:
|
|
||||||
def test_name(self) -> None:
|
|
||||||
assert NoteAgent().get_name() == "note_agent"
|
|
||||||
|
|
||||||
def test_description(self) -> None:
|
|
||||||
assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"
|
|
||||||
|
|
||||||
def test_get_tools_count(self) -> None:
|
|
||||||
assert len(NoteAgent().get_tools()) == 5
|
|
||||||
|
|
||||||
def test_tool_names(self) -> None:
|
|
||||||
names = {t.name for t in NoteAgent().get_tools()}
|
|
||||||
assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_no_tool_calls(self) -> None:
|
|
||||||
with patch("app.agents.note_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Note created.")
|
|
||||||
result = await NoteAgent().handle("create a note", {})
|
|
||||||
assert result == "Note created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_with_create_note_tool_call(self) -> None:
|
|
||||||
with patch("app.agents.note_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
|
||||||
"create_note",
|
|
||||||
{"title": "Daily log", "content": "# Today\nAll good."},
|
|
||||||
"Note 'Daily log' created.",
|
|
||||||
)
|
|
||||||
result = await NoteAgent().handle("log today's progress", {})
|
|
||||||
assert result == "Note 'Daily log' created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_empty_context(self) -> None:
|
|
||||||
with patch("app.agents.note_agent.get_llm") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Done.")
|
|
||||||
result = await NoteAgent().handle("show notes", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestNoteAgentTools:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_notes_no_project(self) -> None:
|
|
||||||
from app.agents.note_agent import list_notes
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
result = await list_notes.ainvoke({})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "select"
|
|
||||||
assert call_kwargs["table"] == "notes"
|
|
||||||
assert call_kwargs["filters"]["projectId"] is None
|
|
||||||
assert result == "No notes found."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_notes_with_project(self) -> None:
|
|
||||||
from app.agents.note_agent import list_notes
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"rows": []}
|
|
||||||
await list_notes.ainvoke({"project_id": "p1"})
|
|
||||||
assert m.call_args.kwargs["filters"]["projectId"] == "p1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_note(self) -> None:
|
|
||||||
from app.agents.note_agent import get_note
|
|
||||||
fake_row = {"id": "n1", "title": "Daily log", "content": "# Today\nAll good."}
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
result = await get_note.ainvoke({"note_id": "n1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "get"
|
|
||||||
assert call_kwargs["table"] == "notes"
|
|
||||||
assert call_kwargs["data"]["id"] == "n1"
|
|
||||||
assert "Daily log" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_note_minimal(self) -> None:
|
|
||||||
from app.agents.note_agent import create_note
|
|
||||||
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
|
|
||||||
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
me.return_value = [0.0] * 1536
|
|
||||||
result = await create_note.ainvoke({"title": "Daily log", "content": "# Today\nAll good."})
|
|
||||||
# First call: insert; second call: vector_upsert
|
|
||||||
first_call = m.call_args_list[0].kwargs
|
|
||||||
assert first_call["action"] == "insert"
|
|
||||||
assert first_call["table"] == "notes"
|
|
||||||
assert first_call["data"]["title"] == "Daily log"
|
|
||||||
assert first_call["data"]["content"] == "# Today\nAll good."
|
|
||||||
assert first_call["data"]["projectId"] is None
|
|
||||||
assert "Daily log" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_note_with_project(self) -> None:
|
|
||||||
from app.agents.note_agent import create_note
|
|
||||||
fake_row = {"id": "n1", "title": "Sprint notes", "projectId": "p1"}
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
|
|
||||||
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
me.return_value = [0.0] * 1536
|
|
||||||
await create_note.ainvoke({"title": "Sprint notes", "content": "## Sprint 1", "project_id": "p1"})
|
|
||||||
first_call = m.call_args_list[0].kwargs
|
|
||||||
assert first_call["data"]["projectId"] == "p1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_note_content_only(self) -> None:
|
|
||||||
from app.agents.note_agent import update_note
|
|
||||||
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
|
|
||||||
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
me.return_value = [0.0] * 1536
|
|
||||||
result = await update_note.ainvoke({"note_id": "n1", "content": "# Updated content"})
|
|
||||||
first_call = m.call_args_list[0].kwargs
|
|
||||||
assert first_call["action"] == "update"
|
|
||||||
assert first_call["data"]["id"] == "n1"
|
|
||||||
assert first_call["data"]["updates"]["content"] == "# Updated content"
|
|
||||||
assert "title" not in first_call["data"]["updates"]
|
|
||||||
assert "n1" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_note_empty_updates(self) -> None:
|
|
||||||
from app.agents.note_agent import update_note
|
|
||||||
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"row": fake_row}
|
|
||||||
await update_note.ainvoke({"note_id": "n1"})
|
|
||||||
assert m.call_args.kwargs["data"]["updates"] == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_note(self) -> None:
|
|
||||||
from app.agents.note_agent import delete_note
|
|
||||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
|
||||||
m.return_value = {"deleted": True}
|
|
||||||
result = await delete_note.ainvoke({"note_id": "n1"})
|
|
||||||
call_kwargs = m.call_args.kwargs
|
|
||||||
assert call_kwargs["action"] == "delete"
|
|
||||||
assert call_kwargs["table"] == "notes"
|
|
||||||
assert call_kwargs["data"]["id"] == "n1"
|
|
||||||
assert "n1" in result
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user