10 Commits

Author SHA1 Message Date
47bf1881e5 deep agent 2026-03-12 18:03:27 +01:00
24a9c1b752 refactor: migrate from create_react_agent to create_deep_agent
- Replace langgraph create_react_agent with deepagents create_deep_agent
- Sub-agents now configured as SubAgent dicts dispatched via built-in task tool
- Stream filter updated: langgraph_node 'agent' → 'model'
- Accept both AIMessage and AIMessageChunk in stream filter
- Collector only captures write mutations (insert/update/delete)
- Add deepagents>=0.4.10 to requirements.txt
2026-03-12 01:21:14 +01:00
706bf88883 fix KeyError 'JSON' — escape braces in chart prompt for str.format() 2026-03-12 00:44:41 +01:00
4ff0b27084 removed WsStreamBlock class 2026-03-12 00:38:31 +01:00
61d2a18234 remove WsStreamBlock schema and tests — no longer used 2026-03-12 00:37:40 +01:00
b3687719b6 add chart tag instructions to HOME system prompt 2026-03-12 00:28:43 +01:00
f80bdfa8f7 simplify HomeFormatter to pass-through — frontend handles entity tag parsing 2026-03-12 00:10:38 +01:00
617a17db40 feat: HomeFormatter parses inline entity tags instead of tool_end blocks
The supervisor LLM now embeds <type>[id1,id2]</type> entity tags in its
response text. The HomeFormatter buffers streamed tokens, detects complete
tags across chunk boundaries, and emits WsStreamBlock with entity type +
specific IDs. This replaces the old approach of emitting blocks for every
tool_end event, which dumped ALL entities regardless of relevance.

Also fixes:
- NoneType guard on metadata in _run_graph_stream (metadata can be None)
- Updated _HOME_SYSTEM prompt with entity tag instructions
- Updated all affected tests
2026-03-12 00:01:06 +01:00
92716cb89a fix: pass tool name as positional arg to @tool decorator
The langchain @tool decorator expects the name as the first positional
argument (name_or_callable), not as name= keyword argument.
2026-03-11 23:32:14 +01:00
cfc9d7a942 refactor: replace orchestrator with LangGraph deep-agent supervisors
- Add app/core/deep_agent.py with Home and Floating supervisor graphs
  using LangGraph create_react_agent (hierarchical pattern)
- Strip ChatAgent classes from all 4 agent files, keep @tool functions
- Rewrite output_formatter.py for event-based (token/tool_end/mutations) stream
- Update device_ws.py to use run_home_stream/run_floating_stream
- Rewrite chat.py REST route to use run_home
- Add update_core_memory tool to both supervisors
- Add langgraph>=0.3.0 to requirements.txt
- Remove orchestrator.py, execution_plan.py, agent_registry.py, plans.py
- Remove PlanAction, PlanStep, ExecutionPlan, execution_mode from schemas
- Update all affected tests to match new API
- Remove 6 deprecated test files for deleted modules
- Clean up stale docstrings referencing removed orchestrator
2026-03-11 17:50:22 +01:00
57 changed files with 5623 additions and 4732 deletions

View File

@@ -23,13 +23,21 @@ LLM_ROUTER_MODEL=gpt-4o-mini
STRIPE_SECRET_KEY= STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET= STRIPE_WEBHOOK_SECRET=
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
S3_BUCKET=adiuva
S3_REGION=us-east-1
S3_ENDPOINT_URL=
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
# ── Langfuse (leave empty to disable observability) ─────────────────────────── # ── Vector Store ──────────────────────────────────────────────────────────────
LANGFUSE_SECRET_KEY= # Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
LANGFUSE_PUBLIC_KEY= PINECONE_API_KEY=
# LANGFUSE_HOST=https://cloud.langfuse.com # EU (default) PINECONE_INDEX=adiuva
# LANGFUSE_HOST=https://us.cloud.langfuse.com # US QDRANT_URL=
# LANGFUSE_HOST=http://localhost:3000 # Self-hosted QDRANT_API_KEY=
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
# ── CORS ────────────────────────────────────────────────────────────────────── # ── CORS ──────────────────────────────────────────────────────────────────────
# Comma-separated list parsed by Settings (override default if needed) # Comma-separated list parsed by Settings (override default if needed)

1
.gitignore vendored
View File

@@ -21,7 +21,6 @@ env/
.pytest_cache/ .pytest_cache/
htmlcov/ htmlcov/
.coverage .coverage
tests/fixtures/private*/
# Docker # Docker
*.log *.log

298
README.md
View File

@@ -1,8 +1,8 @@
# Adiuva Cloud API # Adiuva Cloud API
**AI-powered project management backend with LLM orchestration and subscription billing.** **AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
--- ---
@@ -20,7 +20,9 @@ Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe
- [AI Agent System](#ai-agent-system) - [AI Agent System](#ai-agent-system)
- [Orchestration & Execution Plans](#orchestration--execution-plans) - [Orchestration & Execution Plans](#orchestration--execution-plans)
- [Middleware](#middleware) - [Middleware](#middleware)
- [Storage Layer](#storage-layer)
- [Billing & Tiers](#billing--tiers) - [Billing & Tiers](#billing--tiers)
- [Plugin Marketplace](#plugin-marketplace)
- [Testing](#testing) - [Testing](#testing)
- [Project Structure](#project-structure) - [Project Structure](#project-structure)
- [License](#license) - [License](#license)
@@ -29,13 +31,15 @@ Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe
## Overview ## Overview
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, text embedding generation, and Stripe-based subscription billing across four tiers. Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
### Design Principles ### Design Principles
1. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments. 1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
2. **Stateless request handling** — all context comes from the client and JWT; no server-side session state. 2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
3. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values. 3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
--- ---
@@ -50,26 +54,27 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
│ ┌──────────────────┐ ┌────────────────────────────┐ │ │ ┌──────────────────┐ ┌────────────────────────────┐ │
│ │ Auth Routes │ │ Chat Routes │ │ │ │ Auth Routes │ │ Chat Routes │ │
│ │ Billing Routes │ │ ↓ │ │ │ │ Billing Routes │ │ ↓ │ │
│ │ Agent Routes │ │ Orchestrator (GPT-4o-mini)│ │ │ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
│ │ Device WS │ │ ↓ classify intent │ │ │ │ Backup Routes │ │ ↓ classify intent │ │
└──────────────────┘ │ Agent Registry │ │ │ Plugin Routes │ │ Agent Registry │ │
│ ↓ │ │ │ Vector Routes │ ↓ │ │
│ TaskAgent | ProjectAgent │ │ │ Plans Routes │ TaskAgent | ProjectAgent │ │
│ NoteAgent | CheckptAgent │ │ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
│ │ (GPT-4o + LangChain) │ │ │ │ (GPT-4o + LangChain) │ │
│ └────────────────────────────┘ │ │ └────────────────────────────┘ │
└────────────────────────────────────────────────────────┘ └────────────────────────────────────────────────────────┘
│ │
┌────────▼───┐ ┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
│ PostgreSQL │ │ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
│ (Auth, │ │ (Auth, │ │ (E2E blobs, │ │ Qdrant │
│ Billing, │ │ Billing, │ │ backups) │ │ (Vectors) │
Agents) Metadata) └───────────────┘ └────────────────┘
└────────────┘ └────────────┘
┌────────▼───┐ ┌────────▼───┐
│ Stripe │ │ Stripe │
│ (Billing) │ (Billing,
│ Connect) │
└────────────┘ └────────────┘
``` ```
@@ -80,14 +85,18 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent. 1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain. 2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts. 3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
4. **Text embeddings** — Generates text-embedding-3-small vectors for local client-side note search. 4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
5. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling. 5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
6. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation. 6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
7. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses. 7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
8. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier. 8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
9. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery. 9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
10. **Alembic migrations**Versioned schema management. 10. **Prompt IP protection**Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
11. **Comprehensive test suite** — In-memory SQLite, per-tier test fixtures, and full API coverage without external dependencies. 11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
--- ---
@@ -105,6 +114,7 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration | | `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding | | `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
| `stripe` | ≥ 11.0.0 | Billing and payment integration | | `stripe` | ≥ 11.0.0 | Billing and payment integration |
| `boto3` | ≥ 1.35.0 | AWS S3 client |
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities | | `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder | | `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver | | `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
@@ -114,9 +124,12 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) | | `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
| `websockets` | ≥ 14.0 | WebSocket protocol support | | `websockets` | ≥ 14.0 | WebSocket protocol support |
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) | | `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
| `pytest` | ≥ 8.0.0 | Test framework | | `pytest` | ≥ 8.0.0 | Test framework |
| `pytest-asyncio` | ≥ 0.24.0 | Async test support | | `pytest-asyncio` | ≥ 0.24.0 | Async test support |
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests | | `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
| `ruff` | ≥ 0.8.0 | Linter and formatter | | `ruff` | ≥ 0.8.0 | Linter and formatter |
--- ---
@@ -129,6 +142,7 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
- PostgreSQL 16+ - PostgreSQL 16+
- An OpenAI API key (for LLM features) - An OpenAI API key (for LLM features)
- Stripe API keys (optional — billing stubs gracefully when unconfigured) - Stripe API keys (optional — billing stubs gracefully when unconfigured)
- AWS credentials (optional — needed for S3 storage in production)
### Installation ### Installation
@@ -180,6 +194,11 @@ This starts two services:
- **app** — FastAPI server on port `8000` - **app** — FastAPI server on port `8000`
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks - **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
The compose file also includes optional services for fully local deployments:
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
### Dockerfile Details ### Dockerfile Details
The Dockerfile uses a multi-stage build: The Dockerfile uses a multi-stage build:
@@ -197,7 +216,7 @@ gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0
## Homelab / Self-Hosted Deployment ## Homelab / Self-Hosted Deployment
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
### 1. Start all services ### 1. Start all services
@@ -205,14 +224,35 @@ You can run the entire stack locally on a homelab with **no cloud dependencies e
docker compose up -d docker compose up -d
``` ```
This starts PostgreSQL alongside the app. This starts PostgreSQL, MinIO, and Qdrant alongside the app.
### 2. Configure your `.env` ### 2. Create the MinIO bucket
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
```bash
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
docker compose exec minio mc mb local/adiuva
```
### 3. Configure your `.env`
```bash ```bash
# Database (uses the compose PostgreSQL) # Database (uses the compose PostgreSQL)
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
# S3 → MinIO
S3_BUCKET=adiuva
S3_REGION=us-east-1
S3_ENDPOINT_URL=http://minio:9000
AWS_ACCESS_KEY_ID=minioadmin
AWS_SECRET_ACCESS_KEY=minioadmin
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
QDRANT_URL=http://qdrant:6333
QDRANT_API_KEY=
PINECONE_API_KEY=
# Billing — leave empty to stub (no Stripe needed) # Billing — leave empty to stub (no Stripe needed)
STRIPE_SECRET_KEY= STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET= STRIPE_WEBHOOK_SECRET=
@@ -227,7 +267,7 @@ JWT_SECRET=your-secret-here
ENV=dev ENV=dev
``` ```
### 3. Run migrations ### 4. Run migrations
```bash ```bash
docker compose exec app alembic upgrade head docker compose exec app alembic upgrade head
@@ -238,7 +278,9 @@ docker compose exec app alembic upgrade head
| Service | Runs on | Port | Notes | | Service | Runs on | Port | Notes |
|---|---|---|---| |---|---|---|---|
| FastAPI app | Docker | 8000 | API server | | FastAPI app | Docker | 8000 | API server |
| PostgreSQL | Docker | 5432 | Auth, billing, agents | | PostgreSQL | Docker | 5432 | Auth, billing, metadata |
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
| Stripe | — | — | Stubbed when keys are empty | | Stripe | — | — | Stubbed when keys are empty |
| OpenAI / LLM | Cloud | — | Only external dependency | | OpenAI / LLM | Cloud | — | Only external dependency |
@@ -258,7 +300,17 @@ All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live | | `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live | | `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) | | `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
| `STRIPE_WEBHOOK_SECRET` | `str` | `\"\"` | Stripe webhook signature secret |\n| `OPENAI_API_KEY` | `str` | `\"\"` | OpenAI key for LLM agent calls | | `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
| `S3_REGION` | `str` | `us-east-1` | AWS region |
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) | | `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing | | `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins | | `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
@@ -290,7 +342,6 @@ All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebS
| Method | Path | Auth | Description | | Method | Path | Auth | Description |
|---|---|---|---| |---|---|---|---|
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode | | `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
| `POST` | `/api/v1/chat/embed` | JWT | Generate a 1536-dim text embedding vector (`text-embedding-3-small`). Used by Electron for local note search. |
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. | | `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
### Plans ### Plans
@@ -300,6 +351,42 @@ All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebS
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks | | `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID | | `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
### Storage (Cloud Records)
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
### Vectors (Cloud Vector Store)
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
### Backup
| Method | Path | Auth | Description |
|---|---|---|---|
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
### Plugins (Marketplace)
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
### Billing ### Billing
| Method | Path | Auth | Description | | Method | Path | Auth | Description |
@@ -313,7 +400,7 @@ All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebS
## Data Model ## Data Model
3 tables managed by Alembic migrations. Source: `app/models.py` 9 tables managed by Alembic migrations. Source: `app/models.py`
### Tables ### Tables
@@ -322,18 +409,27 @@ All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebS
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts | | `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation | | `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records | | `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
### Enum Types ### Enum Types
| Enum | Values | | Enum | Values |
|---|---| |---|---|
| `billing_tier` | `free`, `pro`, `power`, `team` | | `billing_tier` | `free`, `pro`, `power`, `team` |
| `plugin_status` | `pending_review`, `approved`, `rejected` |
| `review_decision` | `approved`, `rejected` |
### Migrations ### Migrations
| Version | Description | | Version | Description |
|---|---| |---|---|
| `001_initial_schema` | Creates core auth and billing tables with indexes and foreign key constraints | | `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
--- ---
@@ -343,7 +439,7 @@ The agent system uses a registry pattern with LangChain tool-calling agents powe
### Architecture ### Architecture
- **`BaseAgent`** — Abstract base with `user_id` and `shared_memory`. - **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling. - **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`. - **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
@@ -458,6 +554,39 @@ Source: `app/api/middleware/sanitizer.py`
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`. - Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints. - Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
- Logs sanitization events as `WARNING`. - Logs sanitization events as `WARNING`.
- Binary responses (storage, backup) are never touched.
---
## Storage Layer
### Blob Store
Source: `app/storage/blob_store.py`
- S3-backed storage for E2E encrypted blobs.
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
- The backend **never inspects or decrypts blob content**.
### Vector Store
Source: `app/storage/vector_store.py`
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
- Methods: `upsert()`, `search()`, `delete()`
### Encryption Utilities
Source: `app/storage/encryption.py`
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
- **No decryption key ever reaches the backend.**
--- ---
@@ -471,8 +600,11 @@ Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
|---|---|---|---|---| |---|---|---|---|---|
| AI Agents | 3 | Unlimited | Unlimited | Unlimited | | AI Agents | 3 | Unlimited | Unlimited | Unlimited |
| Batch Active | 2 | 10 | Unlimited | Unlimited | | Batch Active | 2 | 10 | Unlimited | Unlimited |
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited | | LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
| Batch Builder | — | — | ✓ | ✓ | | Batch Builder | — | — | ✓ | ✓ |
| Plugin Marketplace | — | — | ✓ | ✓ |
| SSO | — | — | — | ✓ | | SSO | — | — | — | ✓ |
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min | | Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
@@ -488,6 +620,47 @@ Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
- `get_tier(user_id)` — Returns the user's current billing tier. - `get_tier(user_id)` — Returns the user's current billing tier.
- `check_feature(tier, feature)` — Boolean feature gate check. - `check_feature(tier, feature)` — Boolean feature gate check.
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available. - `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
---
## Plugin Marketplace
Source: `app/marketplace/`
### Plugin Registry
- PostgreSQL-backed catalog of submitted and approved plugins.
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
### Review Queue
- Automated security checklist before human review:
- Plugin ID must match `^[a-z0-9-]+$`
- Permissions must be from the allowed set only
- No binary blobs in the manifest
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:timelines`, `write:timelines`, `read:calendar`, `write:calendar`
- `get_pending(db)` — Lists plugins awaiting review.
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
### Revenue Sharing
- **70% developer / 30% platform** split on all paid plugin sales.
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
- Gracefully stubs transfers when Stripe is not configured.
### Seed Plugins
| Plugin | Category | Price |
|---|---|---|
| GitHub Sync | Productivity | Free |
| Slack Notifier | Communication | €4.99 |
| Time Tracker | Productivity | €9.99 |
--- ---
@@ -509,8 +682,10 @@ pytest -v
### Test Infrastructure ### Test Infrastructure
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed. - **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens. - **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test. - **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`. - **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
- **No external dependencies** — all tests run fully offline. - **No external dependencies** — all tests run fully offline.
@@ -519,6 +694,13 @@ pytest -v
| File | Coverage | | File | Coverage |
|---|---| |---|---|
| `test_auth.py` | Register, login, token access, refresh, expiration | | `test_auth.py` | Register, login, token access, refresh, expiration |
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection | | `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
--- ---
@@ -528,6 +710,7 @@ pytest -v
``` ```
adiuva-api/ adiuva-api/
├── alembic.ini # Alembic configuration ├── alembic.ini # Alembic configuration
├── BACKEND_PLAN.md # Architecture & design decisions
├── docker-compose.yml # Docker Compose (app + PostgreSQL) ├── docker-compose.yml # Docker Compose (app + PostgreSQL)
├── Dockerfile # Multi-stage production build ├── Dockerfile # Multi-stage production build
├── requirements.txt # Python dependencies ├── requirements.txt # Python dependencies
@@ -536,12 +719,13 @@ adiuva-api/
│ ├── env.py # Alembic environment config │ ├── env.py # Alembic environment config
│ ├── script.py.mako # Migration template │ ├── script.py.mako # Migration template
│ └── versions/ │ └── versions/
── 001_initial_schema.py # Tables, indexes, FKs ── 001_initial_schema.py # Tables, indexes, FKs
│ └── 002_seed_plugins.py # Seed marketplace plugins
├── app/ # Application source ├── app/ # Application source
│ ├── main.py # FastAPI app factory, middleware, routes │ ├── main.py # FastAPI app factory, middleware, routes
│ ├── db.py # Async SQLAlchemy engine & session │ ├── db.py # Async SQLAlchemy engine & session
│ ├── models.py # SQLAlchemy ORM models │ ├── models.py # SQLAlchemy ORM models (9 tables)
│ ├── schemas.py # Pydantic request/response schemas │ ├── schemas.py # Pydantic request/response schemas
│ │ │ │
│ ├── config/ │ ├── config/
@@ -556,29 +740,47 @@ adiuva-api/
│ ├── core/ # Orchestration engine │ ├── core/ # Orchestration engine
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry │ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm) │ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
│ │ ── deep_agent.py # Deep agent orchestration │ │ ── orchestrator.py # Intent classification & routing
│ │ └── execution_plan.py # Plan builder, templates, cache
│ │ │ │
│ ├── api/ # HTTP layer │ ├── api/ # HTTP layer
│ │ ├── deps.py # Shared FastAPI dependencies │ │ ├── deps.py # Shared FastAPI dependencies
│ │ ├── middleware/ │ │ ├── middleware/
│ │ │ ├── auth.py # JWT validation, live tier lookup
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter │ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
│ │ │ └── sanitizer.py # Prompt IP leak protection │ │ │ └── sanitizer.py # Prompt IP leak protection
│ │ └── routes/ │ │ └── routes/
│ │ ├── auth.py # Register, login, refresh, me │ │ ├── auth.py # Register, login, refresh, me
│ │ ├── chat.py # Chat + embed endpoint │ │ ├── chat.py # Chat + WebSocket streaming
│ │ ├── billing.py # Stripe checkout, webhooks, subscription │ │ ├── plans.py # Execution plan playbooks
│ │ ├── agents.py # Agent catalog, config, runs │ │ ├── storage.py # E2E encrypted record CRUD
│ │ ── device_ws.py # Persistent device WebSocket │ │ ── vectors.py # Vector upsert, search, delete
│ │ ├── backup.py # Encrypted backup management
│ │ ├── plugins.py # Marketplace browse & install
│ │ └── billing.py # Stripe checkout & webhooks
│ │ │ │
── billing/ ── storage/ # Storage backends
├── stripe_service.py # Stripe API wrapper ├── blob_store.py # S3 blob storage
── tier_manager.py # Feature matrix, rate limits ── vector_store.py # Pinecone / Qdrant vector store
│ │ └── encryption.py # Checksum verification utilities
│ │
│ ├── billing/ # Subscription management
│ │ ├── stripe_service.py # Stripe API integration
│ │ └── tier_manager.py # Feature matrix & quota enforcement
│ │
│ └── marketplace/ # Plugin ecosystem
│ ├── plugin_registry.py # Catalog CRUD & search
│ ├── plugin_review.py # Security checklist & review queue
│ └── revenue_share.py # 70/30 split & Stripe Connect
└── tests/ # Test suite └── tests/ # Test suite
├── conftest.py # Fixtures: DB, auth, seeds ├── conftest.py # Fixtures: DB, S3, auth, seeds
├── test_auth.py ├── test_auth.py
├── test_orchestrator.py ├── test_orchestrator.py
├── test_agents.py ├── test_agents.py
├── test_storage.py
├── test_backup.py
├── test_plugins.py
├── test_agent_registry.py ├── test_agent_registry.py
├── test_execution_plan.py ├── test_execution_plan.py
└── test_middleware.py └── test_middleware.py

View File

@@ -1,4 +1,5 @@
"""Initial schema: users, refresh_tokens, subscriptions. """Initial schema: users, refresh_tokens, subscriptions, storage_records,
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
Revision ID: 001 Revision ID: 001
Revises: Revises:
@@ -27,6 +28,18 @@ def upgrade() -> None:
EXCEPTION WHEN duplicate_object THEN NULL; EXCEPTION WHEN duplicate_object THEN NULL;
END $$; END $$;
""") """)
op.execute("""
DO $$ BEGIN
CREATE TYPE plugin_status AS ENUM ('pending_review', 'approved', 'rejected');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE review_decision AS ENUM ('approved', 'rejected');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
# ── users ───────────────────────────────────────────────────────────── # ── users ─────────────────────────────────────────────────────────────
op.create_table( op.create_table(
@@ -75,10 +88,122 @@ def upgrade() -> None:
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"]) op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"]) op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
# ── storage_records ───────────────────────────────────────────────────
op.create_table(
"storage_records",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("table_name", sa.String(100), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
# ── backup_metadata ───────────────────────────────────────────────────
op.create_table(
"backup_metadata",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("version", sa.Integer, nullable=False),
sa.Column("timestamp", sa.BigInteger, nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
# ── plugins ───────────────────────────────────────────────────────────
op.create_table(
"plugins",
sa.Column("id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("description", sa.Text, nullable=False, server_default=""),
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
sa.Column("category", sa.String(100), nullable=False, server_default=""),
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
sa.Column("status", postgresql.ENUM("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
sa.Column("s3_package_key", sa.String(500), nullable=True),
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
sa.Column("rejection_reason", sa.Text, nullable=True),
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
)
# ── plugin_installations ──────────────────────────────────────────────
op.create_table(
"plugin_installations",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
)
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
# ── plugin_reviews ────────────────────────────────────────────────────
op.create_table(
"plugin_reviews",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("decision", postgresql.ENUM("approved", "rejected", name="review_decision", create_type=False), nullable=False),
sa.Column("notes", sa.Text, nullable=True),
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
)
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
# ── revenue_events ────────────────────────────────────────────────────
op.create_table(
"revenue_events",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
def downgrade() -> None: def downgrade() -> None:
op.drop_table("revenue_events")
op.drop_table("plugin_reviews")
op.drop_table("plugin_installations")
op.drop_table("plugins")
op.drop_table("backup_metadata")
op.drop_table("storage_records")
op.drop_table("subscriptions") op.drop_table("subscriptions")
op.drop_table("refresh_tokens") op.drop_table("refresh_tokens")
op.drop_table("users") op.drop_table("users")
op.execute("DROP TYPE IF EXISTS review_decision")
op.execute("DROP TYPE IF EXISTS plugin_status")
op.execute("DROP TYPE IF EXISTS billing_tier") op.execute("DROP TYPE IF EXISTS billing_tier")

View File

@@ -0,0 +1,92 @@
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
Revision ID: 002
Revises: 001
Create Date: 2026-03-03
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "002"
down_revision: Union[str, None] = "001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
_SEED_PLUGINS = [
{
"id": "plugin-github-sync",
"name": "GitHub Sync",
"description": "Sync tasks with GitHub Issues and pull requests.",
"version": "1.0.0",
"author_name": "Adiuva",
"category": "productivity",
"price_cents": 0,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-slack-notify",
"name": "Slack Notifier",
"description": "Post task and timeline updates to Slack channels.",
"version": "1.2.0",
"author_name": "Adiuva",
"category": "communication",
"price_cents": 499,
"permissions": json.dumps(["read:tasks", "read:timelines"]),
"status": "approved",
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-time-tracker",
"name": "Time Tracker",
"description": "Track time spent on tasks with automatic reporting.",
"version": "0.9.1",
"author_name": "Third Party",
"category": "productivity",
"price_cents": 999,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
]
def upgrade() -> None:
plugins = sa.table(
"plugins",
sa.column("id", sa.String),
sa.column("name", sa.String),
sa.column("description", sa.Text),
sa.column("version", sa.String),
sa.column("author_name", sa.String),
sa.column("category", sa.String),
sa.column("price_cents", sa.Integer),
sa.column("permissions", sa.Text),
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
sa.column("s3_package_key", sa.String),
sa.column("install_count", sa.Integer),
sa.column("avg_rating", sa.Float),
)
op.bulk_insert(plugins, _SEED_PLUGINS)
def downgrade() -> None:
op.execute(
"DELETE FROM plugins WHERE id IN ("
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
")"
)

View File

@@ -1,92 +0,0 @@
"""Deprecate backend agent config tables.
The Electron client is now the source of truth for agent configuration
(directory, extract targets, batch interval, custom prompt). Backend keeps
billing checks and trigger/run logs only.
Revision ID: 9a1f2d0b6c7e
Revises: 818478c251dc
Create Date: 2026-03-16
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "9a1f2d0b6c7e"
down_revision: Union[str, None] = "818478c251dc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
existing = set(inspector.get_table_names())
if "cloud_agent_configs" in existing:
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
op.drop_table("cloud_agent_configs")
if "local_agent_configs" in existing:
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
op.drop_table("local_agent_configs")
def downgrade() -> None:
op.create_table(
"local_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("device_id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
op.execute(
"""
DO $$ BEGIN
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
"""
)
op.create_table(
"cloud_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column(
"provider",
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
nullable=False,
),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
sa.Column("filter_config", sa.JSON, nullable=True),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])

View File

@@ -1,5 +1,5 @@
"""Expose tool modules used by deep orchestrator-worker graphs.""" """Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent from app.agents import timeline_agent, note_agent, project_agent, task_agent
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"] __all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]

View File

@@ -1,85 +0,0 @@
"""Filesystem agent — tools for reading local directories and files on Electron.
These tools delegate to the Electron client via ``execute_on_client()`` using
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
handles actual disk I/O and responds with ``tool_result`` frames.
"""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
@tool
async def list_directory(path: str) -> str:
"""List files and folders in a local directory on the user's device.
Returns a formatted listing of entries with name, type (file/directory),
and full path.
"""
result = await execute_on_client(
action="list_directory",
data={"path": path},
)
entries: list[dict[str, Any]] = result.get("entries", [])
if not entries:
return f"Directory '{path}' is empty or does not exist."
lines: list[str] = []
for entry in entries:
entry_type = entry.get("type", "unknown")
entry_name = entry.get("name", "")
entry_path = entry.get("path", "")
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
@tool
async def read_file_content(path: str) -> str:
"""Read the text content of a local file on the user's device.
Returns the file content as a string. Large files may be truncated
by the Electron client.
"""
result = await execute_on_client(
action="read_file_content",
data={"path": path},
)
content: str = result.get("content", "")
if not content:
return f"File '{path}' is empty or could not be read."
return content
@tool
async def get_file_metadata(path: str) -> str:
"""Get metadata for a local file: size, creation date, modification date, extension.
Returns a formatted summary of the file's metadata.
"""
result = await execute_on_client(
action="get_file_metadata",
data={"path": path},
)
size = result.get("size", "unknown")
created = result.get("createdAt", "unknown")
modified = result.get("modifiedAt", "unknown")
extension = result.get("extension", "unknown")
name = result.get("name", path)
return (
f"File: {name}\n"
f" Extension: {extension}\n"
f" Size: {size} bytes\n"
f" Created: {created}\n"
f" Modified: {modified}"
)
FILESYSTEM_TOOLS: list[Any] = [
list_directory,
read_file_content,
get_file_metadata,
]

View File

@@ -1,8 +1,7 @@
"""Note agent — Markdown note management (list, get, create, update, delete).""" """Note agent — tool definitions for Markdown note CRUD."""
from __future__ import annotations from __future__ import annotations
import re
from typing import Any from typing import Any
from langchain_core.tools import tool from langchain_core.tools import tool
@@ -10,38 +9,14 @@ from langchain_core.tools import tool
from app.core.llm import embed from app.core.llm import embed
from app.core.ws_context import execute_on_client from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
NOTE_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
" - content is always Markdown; preserve formatting when updating\n"
" - project_id is optional; link a note to a project when mentioned\n"
" - When updating, call get_note first if you need to read existing content\n"
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@tool @tool
async def list_notes(project_id: str = "") -> str: async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id.""" """List notes, optionally scoped to a project by project_id."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client( result = await execute_on_client(
action="select", action="select",
table="notes", table="notes",
filters={"projectId": normalized_project_id or None}, filters={"projectId": project_id or None},
) )
rows = result.get("rows", []) rows = result.get("rows", [])
if not rows: if not rows:
@@ -130,10 +105,4 @@ async def delete_note(note_id: str) -> str:
return f"Note {note_id} deleted." return f"Note {note_id} deleted."
NOTE_TOOLS: list[Any] = [
list_notes,
get_note,
create_note,
update_note,
delete_note,
]

View File

@@ -1,4 +1,4 @@
"""Project agent — full lifecycle management (list, get, create, update, archive, delete).""" """Project agent — tool definitions for project lifecycle CRUD."""
from __future__ import annotations from __future__ import annotations
@@ -8,22 +8,6 @@ from langchain_core.tools import tool
from app.core.ws_context import execute_on_client from app.core.ws_context import execute_on_client
PROJECT_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - client_id is optional; link to a client only when explicitly mentioned\n"
" - ai_summary is populated only when the user asks for a project summary;\n"
" derive it from context data — do not fabricate content\n"
" - Use list_projects for scoped queries; list_all_projects only when the\n"
" user wants a complete cross-client view including archived projects\n"
" - get_project requires a project UUID; resolve the ID first by calling\n"
" list_projects if you only have a project name\n"
" - Prefer archiving (update_project status=archived) over deletion;\n"
" only call delete_project when the user explicitly confirms deletion."
)
@tool @tool
async def list_projects( async def list_projects(
@@ -133,11 +117,4 @@ async def delete_project(project_id: str) -> str:
return f"Project {project_id} permanently deleted." return f"Project {project_id} permanently deleted."
PROJECT_TOOLS: list[Any] = [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]

View File

@@ -1,40 +1,14 @@
"""Task agent — full CRUD for tasks and task comments.""" """Task agent — tool definitions for task and task comment CRUD."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timezone
import re
from typing import Any from typing import Any
from langchain_core.tools import tool from langchain_core.tools import tool
from app.core.ws_context import execute_on_client from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TASK_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n"
" - status must be one of: todo, in_progress, done\n"
" - priority must be one of: high, medium, low\n"
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
" - project_id is optional; link to a project when the user mentions one\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
" did not explicitly request; 0 otherwise\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
" - Use list_tasks_due_today for 'what's due today' queries\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
)
# ── Task tools ──────────────────────────────────────────────────────── # ── Task tools ────────────────────────────────────────────────────────
@@ -48,12 +22,11 @@ async def list_tasks(
) -> str: ) -> str:
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done), """List tasks, optionally filtered by project_id, status (todo|in_progress|done),
a search string, or an order_by field name (dueDate|priority|createdAt).""" a search string, or an order_by field name (dueDate|priority|createdAt)."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client( result = await execute_on_client(
action="select", action="select",
table="tasks", table="tasks",
filters={ filters={
"projectId": normalized_project_id or None, "projectId": project_id or None,
"status": status or None, "status": status or None,
"search": search or None, "search": search or None,
"orderBy": order_by or None, "orderBy": order_by or None,
@@ -79,6 +52,7 @@ async def create_task(
due_date: int = 0, due_date: int = 0,
project_id: str = "", project_id: str = "",
is_ai_suggested: int = 0, is_ai_suggested: int = 0,
is_approved: int = 0,
) -> str: ) -> str:
"""Create a new task. """Create a new task.
title: task title (required) title: task title (required)
@@ -89,6 +63,7 @@ async def create_task(
due_date: Unix timestamp in milliseconds; 0 means no due date due_date: Unix timestamp in milliseconds; 0 means no due date
project_id: optional UUID of the parent project project_id: optional UUID of the parent project
is_ai_suggested: 1 if proactively suggested, 0 if user-requested is_ai_suggested: 1 if proactively suggested, 0 if user-requested
is_approved: 0 until the user confirms; 1 when confirmed
""" """
result = await execute_on_client( result = await execute_on_client(
action="insert", action="insert",
@@ -102,6 +77,7 @@ async def create_task(
"dueDate": due_date or None, "dueDate": due_date or None,
"projectId": project_id or None, "projectId": project_id or None,
"isAiSuggested": is_ai_suggested, "isAiSuggested": is_ai_suggested,
"isApproved": is_approved,
}, },
) )
row = result["row"] row = result["row"]
@@ -121,10 +97,12 @@ async def update_task(
assignees: str = "", assignees: str = "",
due_date: int = -1, due_date: int = -1,
project_id: str = "", project_id: str = "",
is_approved: int = -1,
) -> str: ) -> str:
"""Update fields on an existing task. Only pass fields you want to change. """Update fields on an existing task. Only pass fields you want to change.
task_id: the task's UUID (required) task_id: the task's UUID (required)
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
is_approved: -1 means unchanged; 0 or 1 sets the value
""" """
updates: dict[str, Any] = {} updates: dict[str, Any] = {}
if title: if title:
@@ -141,6 +119,8 @@ async def update_task(
updates["dueDate"] = due_date or None updates["dueDate"] = due_date or None
if project_id: if project_id:
updates["projectId"] = project_id updates["projectId"] = project_id
if is_approved != -1:
updates["isApproved"] = is_approved
result = await execute_on_client( result = await execute_on_client(
action="update", action="update",
table="tasks", table="tasks",
@@ -208,12 +188,8 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
table="taskComments", table="taskComments",
data={"taskId": task_id, "author": author, "content": content}, data={"taskId": task_id, "author": author, "content": content},
) )
row = result.get("row", {}) row = result["row"]
row_author = row.get("author", author) return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
row_task_id = row.get("taskId") or row.get("task_id") or task_id
row_comment_id = row.get("id", "unknown")
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
@tool @tool
@@ -223,16 +199,4 @@ async def delete_task_comment(comment_id: str) -> str:
return f"Comment {comment_id} deleted." return f"Comment {comment_id} deleted."
# ── Agent ─────────────────────────────────────────────────────────────
TASK_TOOLS: list[Any] = [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]

View File

@@ -1,45 +1,21 @@
"""Timeline agent — project milestone management (list, create, update, delete).""" """Timeline agent — tool definitions for project milestone CRUD."""
from __future__ import annotations from __future__ import annotations
import re
from typing import Any from typing import Any
from langchain_core.tools import tool from langchain_core.tools import tool
from app.core.ws_context import execute_on_client from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TIMELINE_SYSTEM_PROMPT = (
"You are a project timeline assistant. Timelines are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - For update_timeline, use -1 for integer fields you do not want to change\n"
" - Listing without a project_id returns all timelines across projects\n"
" - Always echo the title and formatted date in your confirmation."
)
@tool @tool
async def list_timelines(project_id: str = "") -> str: async def list_timelines(project_id: str = "") -> str:
"""List timelines. Provide project_id to scope to a specific project.""" """List timelines. Provide project_id to scope to a specific project."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client( result = await execute_on_client(
action="select", action="select",
table="timelines", table="timelines",
filters={"projectId": normalized_project_id or None}, filters={"projectId": project_id or None},
) )
rows = result.get("rows", []) rows = result.get("rows", [])
if not rows: if not rows:
@@ -54,12 +30,14 @@ async def create_timeline(
title: str, title: str,
date: int, date: int,
is_ai_suggested: int = 0, is_ai_suggested: int = 0,
is_approved: int = 0,
) -> str: ) -> str:
"""Create a project timeline (milestone). """Create a project timeline (milestone).
project_id: REQUIRED UUID of the parent project project_id: REQUIRED UUID of the parent project
title: descriptive name for the milestone title: descriptive name for the milestone
date: Unix timestamp in milliseconds date: Unix timestamp in milliseconds
is_ai_suggested: 1 if proactively suggested, 0 if user-requested is_ai_suggested: 1 if proactively suggested, 0 if user-requested
is_approved: 0 until the user confirms
""" """
result = await execute_on_client( result = await execute_on_client(
action="insert", action="insert",
@@ -69,6 +47,7 @@ async def create_timeline(
"title": title, "title": title,
"date": date, "date": date,
"isAiSuggested": is_ai_suggested, "isAiSuggested": is_ai_suggested,
"isApproved": is_approved,
}, },
) )
row = result["row"] row = result["row"]
@@ -80,16 +59,20 @@ async def update_timeline(
timeline_id: str, timeline_id: str,
title: str = "", title: str = "",
date: int = -1, date: int = -1,
is_approved: int = -1,
) -> str: ) -> str:
"""Update a timeline. Only pass fields that should change. """Update a timeline. Only pass fields that should change.
timeline_id: UUID of the timeline (required) timeline_id: UUID of the timeline (required)
date: -1 means unchanged; any other value sets the new date (ms timestamp) date: -1 means unchanged; any other value sets the new date (ms timestamp)
is_approved: -1 means unchanged; 0 or 1 sets the approval state
""" """
updates: dict[str, Any] = {} updates: dict[str, Any] = {}
if title: if title:
updates["title"] = title updates["title"] = title
if date != -1: if date != -1:
updates["date"] = date updates["date"] = date
if is_approved != -1:
updates["isApproved"] = is_approved
result = await execute_on_client( result = await execute_on_client(
action="update", action="update",
table="timelines", table="timelines",
@@ -106,9 +89,4 @@ async def delete_timeline(timeline_id: str) -> str:
return f"Timeline {timeline_id} deleted." return f"Timeline {timeline_id} deleted."
TIMELINE_TOOLS: list[Any] = [
list_timelines,
create_timeline,
update_timeline,
delete_timeline,
]

View File

@@ -55,15 +55,12 @@ async def get_current_user(
raise credentials_exc raise credentials_exc
# Live tier lookup — subscription row is the authoritative source. # Live tier lookup — subscription row is the authoritative source.
# In dev, fall back to 'power' (unlimited) so quota limits don't
# block local development when no Stripe subscription exists.
from app.models import Subscription, User # noqa: PLC0415 from app.models import Subscription, User # noqa: PLC0415
result = await db.execute( result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id) select(Subscription.tier).where(Subscription.user_id == user_id)
) )
default_tier = "power" if settings.ENV == "dev" else "free" tier: str = result.scalar_one_or_none() or "free"
tier: str = result.scalar_one_or_none() or default_tier
# Fetch name/surname from user row. # Fetch name/surname from user row.
user_result = await db.execute( user_result = await db.execute(

View File

@@ -8,7 +8,8 @@ that could reveal server-side prompt IP:
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …) - Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
- Exact-match known prompt fingerprints - Exact-match known prompt fingerprints
The middleware only activates for paths under /api/v1/chat. Binary responses (storage blobs, backup data) are never touched — the
middleware only activates for paths under /api/v1/chat.
Any sanitisation event is logged as a WARNING with the request path and the Any sanitisation event is logged as a WARNING with the request path and the
names of the fields that were modified. names of the fields that were modified.

View File

@@ -1,42 +1,54 @@
"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template. """Chatbot Journey endpoints — guided conversation to build an agent prompt_template.
The journey is driven entirely through WebSocket frames (no REST endpoints). Endpoints:
The device WS handler dispatches ``journey_start`` and ``journey_message`` POST /agents/journey/start — start a new journey session
frames to the functions exported here. POST /agents/journey/message — continue the conversation
Sessions are stored in-memory with a 30-minute TTL. Stale entries are
cleaned up lazily on access. Upgrade to Redis for multi-instance deployments.
Journey flow: Journey flow:
1. FE sends ``journey_start`` frame with basic agent config (directory, 1. Client sends ``{ agent_type, agent_id? }`` to ``/start``.
data_types, schedule). 2. Server creates a session, calls the LLM with a contextual system prompt,
2. Server creates an in-memory session, sets up a WS executor so the and returns the first question.
setup LLM can use file-system tools, does a first directory scrape, 3. Client sends follow-up messages to ``/message``.
and sends back a ``journey_reply`` with the first question. 4. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` block
3. FE sends ``journey_message`` frames for each user reply. delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
4. Server appends the user message, calls the LLM (which may read files 5. Server parses the block, sets ``done=True``, and returns the template.
via tools), and sends back a ``journey_reply``.
5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` The ``prompt_template`` from the final response is meant to be stored in
block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``. ``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
6. Server parses the block, sends ``journey_reply`` with ``done=True`` by the Electron client (via the agent CRUD endpoints).
and the template. FE stores it locally.
""" """
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
import time import time
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from fastapi import APIRouter, Depends, HTTPException, status
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.filesystem_agent import FILESYSTEM_TOOLS from app.api.deps import get_current_user
from app.config.settings import settings
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
from app.core.llm import get_llm from app.core.llm import get_llm
from app.db import get_session
from app.models import CloudAgentConfig, LocalAgentConfig
from app.schemas import (
JourneyMessageRequest,
JourneyResponse,
JourneyStartRequest,
UserProfile,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/agents/journey", tags=["agents"])
# ── Session TTL ─────────────────────────────────────────────────────────── # ── Session TTL ───────────────────────────────────────────────────────────
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes _SESSION_TTL_SECONDS: int = 1800 # 30 minutes
@@ -45,26 +57,18 @@ _SESSION_TTL_SECONDS: int = 1800 # 30 minutes
_TEMPLATE_START = "PROMPT_TEMPLATE_START" _TEMPLATE_START = "PROMPT_TEMPLATE_START"
_TEMPLATE_END = "PROMPT_TEMPLATE_END" _TEMPLATE_END = "PROMPT_TEMPLATE_END"
# Minimum turns before we consider nudging the LLM to wrap up. # Maximum number of conversation turns before the LLM is nudged to wrap up.
_MIN_TURNS_BEFORE_NUDGE: int = 3 _MAX_TURNS: int = 5
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
_MAX_TURNS: int = 15
# Max tool-calling steps per LLM invocation.
_MAX_TOOL_STEPS: int = 6
# ── In-memory session store ─────────────────────────────────────────────── # ── In-memory session store ───────────────────────────────────────────────
@dataclass @dataclass
class JourneySession: class _JourneySession:
session_id: str session_id: str
user_id: str user_id: str
agent_type: str # "local" | "cloud" agent_type: str # "local" | "cloud"
directory: str
data_types: list[str]
history: list[dict[str, Any]] = field(default_factory=list) history: list[dict[str, Any]] = field(default_factory=list)
system_prompt: str = ""
langfuse_prompt: Any = None
created_at: float = field(default_factory=time.monotonic) created_at: float = field(default_factory=time.monotonic)
def is_expired(self) -> bool: def is_expired(self) -> bool:
@@ -72,102 +76,84 @@ class JourneySession:
# session_id → session # session_id → session
_sessions: dict[str, JourneySession] = {} _sessions: dict[str, _JourneySession] = {}
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None: def _get_session(session_id: str, user_id: str) -> _JourneySession:
"""Retrieve session; return None on missing, expired, or wrong owner.""" """Retrieve session; raise 404 on missing, expired, or wrong owner."""
s = _sessions.get(session_id) s = _sessions.get(session_id)
if s is None or s.is_expired(): if s is None or s.is_expired():
_sessions.pop(session_id, None) _sessions.pop(session_id, None)
return None raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
if s.user_id != user_id: if s.user_id != user_id:
return None raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
return s return s
# ── System prompt builder ───────────────────────────────────────────────── # ── System prompt builder ─────────────────────────────────────────────────
_JOURNEY_SYSTEM_PROMPT = """\ _LOCAL_PREAMBLE = """\
What kind of files are in the directories you want to monitor? \
(for example: emails saved as .eml, documents in .pdf or .txt, markdown notes, etc.)"""
_CLOUD_PREAMBLE = """\
What kind of emails or messages should I look for? \
(for example: client communications, invoices, meeting notes, project updates, etc.)"""
_SYSTEM_PROMPT_TEMPLATE = """\
You are a friendly assistant helping a freelancer configure a data-extraction agent. You are a friendly assistant helping a freelancer configure a data-extraction agent.
Your job is to understand exactly what data the user wants to extract from their Your job is to understand exactly what data the user wants to extract from their {source_description} \
local directory and produce a detailed prompt_template that a separate AI will use and produce a detailed prompt_template that a separate AI will use as its instruction set.
as its instruction set.
The extraction agent already has this base behaviour built in: Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
- Reads each file using file-system tools. 1. The type and format of the source content.
- Creates records (tasks, notes, timelines, projects) via CRUD tools. 2. Which data types to extract: tasks, notes, timelines, and/or projects.
- Sets isAiSuggested=1 on every new record. 3. How fields should be mapped (e.g. email subject → task title).
- Only extracts data explicitly present in the files — it never invents information. 4. Priority or status rules (e.g. "urgent" keyword → high priority).
The user's custom prompt is appended AFTER this base behaviour, so focus on 5. Any special handling, date extraction, or exclusions.
what to look for and how to map it — not on the general extraction mechanics.
You have access to file-system tools to explore the user's directory: After 3-5 questions (when you have enough information), output the final prompt_template between \
- list_directory: to see folder structure these exact markers on their own lines:
- read_file_content: to peek at file contents
- get_file_metadata: to check file info
The user's configured directory is: {directory}
Target data types: {data_types}
IMPORTANT — project assignment is handled automatically by the main agent runner
before the custom prompt is ever used. You MUST NOT ask the user about projects,
projectId, or how to link records to projects. Never include projectId logic or
project creation instructions in the generated prompt_template.
Start by exploring the directory to understand its structure. Then ask concise,
focused questions one at a time. Cover these topics (not necessarily in this order):
1. The type and format of the source content (confirmed by your exploration).
2. How fields should be mapped (e.g. filename → task title).
3. Priority or status rules (e.g. "urgent" keyword → high priority).
4. Any special handling, date extraction, or exclusions.
Once you reach 90% confidence, output the final prompt_template between these exact
markers on their own lines:
{template_start} {template_start}
<the complete extraction prompt here> <the complete extraction prompt here>
{template_end} {template_end}
The prompt_template must be a self-contained instruction for an AI that reads files The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
and must perform CRUD operations using tools to create records. It should specify: and must return a JSON array of records in this shape:
- What entity types to create (tasks, notes, timelines) — never projects. [{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
- How to map file content to record fields (camelCase: title, status, priority,
dueDate, content, etc.) — never include projectId.
- That isAiSuggested must be set to 1 on every new record.
- Concrete examples of mappings based on what you discovered in the directory.
Rules for the generated template:
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
- Include concrete examples of mappings.
- Mention that Electron adds id/createdAt/updatedAt automatically.
- Set isAiSuggested: true and isApproved: false on every record.
{existing_section}\ {existing_section}\
Keep asking clarifying questions until you are at least 90% confident you have Do not ask more than {max_turns} questions total. Start with your first question now.\
enough information to generate an accurate prompt_template. Once you reach that
confidence level, stop asking and produce the final template immediately.
Begin by exploring the directory, then ask your first question.\
""" """
def _build_system_prompt( def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
directory: str, source_description = (
data_types: list[str], "files in local directories" if agent_type == "local" else "emails and messages from cloud providers"
existing_template: str | None = None, )
) -> tuple[str, Any]:
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
existing_section = ( existing_section = (
f"\nThe user already has the following prompt_template — refine it based on their answers:\n" f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
f"---\n{existing_template}\n---\n" f"---\n{existing_template}\n---\n"
if existing_template if existing_template
else "" else ""
) )
template, prompt_obj = get_prompt_or_fallback( return _SYSTEM_PROMPT_TEMPLATE.format(
"journey_system", _JOURNEY_SYSTEM_PROMPT source_description=source_description,
)
compiled = template.format(
directory=directory,
data_types=", ".join(data_types),
template_start=_TEMPLATE_START, template_start=_TEMPLATE_START,
template_end=_TEMPLATE_END, template_end=_TEMPLATE_END,
existing_section=existing_section, existing_section=existing_section,
max_turns=_MAX_TURNS,
) )
return compiled, prompt_obj
def _first_question(agent_type: str) -> str:
return _LOCAL_PREAMBLE if agent_type == "local" else _CLOUD_PREAMBLE
# ── Template extraction ─────────────────────────────────────────────────── # ── Template extraction ───────────────────────────────────────────────────
@@ -182,42 +168,11 @@ def _extract_template(text: str) -> str | None:
return text[start_idx:end_idx].strip() or None return text[start_idx:end_idx].strip() or None
# ── LLM call with tool support ─────────────────────────────────────────── # ── LLM call ─────────────────────────────────────────────────────────────
def _as_text(content: Any) -> str: async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
if content is None: """Build LangChain messages from history and invoke the LLM."""
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _call_llm_with_tools(
system_prompt: str,
history: list[dict[str, Any]],
tools: list[Any],
*,
user_id: str = "",
session_id: str = "",
langfuse_prompt: Any = None,
) -> str:
"""Build LangChain messages from history and invoke the LLM with tools.
Handles tool-calling loops: if the LLM calls tools, execute them and
continue until a final text response is produced.
"""
lf = get_langfuse()
messages: list[Any] = [SystemMessage(content=system_prompt)] messages: list[Any] = [SystemMessage(content=system_prompt)]
for turn in history: for turn in history:
if turn["role"] == "user": if turn["role"] == "user":
@@ -226,242 +181,137 @@ async def _call_llm_with_tools(
messages.append(AIMessage(content=turn["content"])) messages.append(AIMessage(content=turn["content"]))
llm = get_llm(model=None, temperature=0.4) llm = get_llm(model=None, temperature=0.4)
llm_with_tools = llm.bind_tools(tools) response = await llm.ainvoke(messages)
tool_map = {tool_def.name: tool_def for tool_def in tools} return response.content # type: ignore[return-value]
_span_ctx = (
lf.start_as_current_observation(
as_type="span",
name="journey-setup",
user_id=user_id or None,
session_id=session_id or None,
input=history[-1]["content"] if history else "",
)
if lf else None
)
_span = _span_ctx.__enter__() if _span_ctx else None
try:
for _ in range(_MAX_TOOL_STEPS):
_gen_ctx = (
lf.start_as_current_observation(
as_type="generation",
name="journey-setup-llm",
model=settings.LLM_MODEL,
prompt=langfuse_prompt,
input=messages,
)
if lf else None
)
_gen = _gen_ctx.__enter__() if _gen_ctx else None
response: AIMessage = await llm_with_tools.ainvoke(messages)
if _gen_ctx:
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
_gen_ctx.__exit__(None, None, None)
messages.append(response)
if not response.tool_calls:
if _span:
_span.update(output=_as_text(response.content))
return _as_text(response.content)
for call in response.tool_calls:
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"agent_setup: journey tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:500],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"agent_setup: journey tool_result name=%s output=%s",
call_name,
str(tool_output)[:800],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
# Fallback: exceeded max steps.
final = await llm.ainvoke(messages)
final_text = _as_text(final.content)
if _span:
_span.update(output=final_text)
return final_text
finally:
if _span_ctx:
_span_ctx.__exit__(None, None, None)
if lf:
lf.flush()
# ── Journey handlers (called from device_ws.py) ────────────────────────── # ── Existing-config loader ────────────────────────────────────────────────
async def handle_journey_start( async def _load_existing_template(
agent_id: str,
user_id: str, user_id: str,
frame: dict[str, Any], db: AsyncSession,
) -> dict[str, Any]: ) -> str | None:
"""Handle a ``journey_start`` WS frame. """Return the prompt_template of an existing agent config, or None."""
# Try local first, then cloud.
local_result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.id == agent_id,
LocalAgentConfig.user_id == user_id,
)
)
local = local_result.scalar_one_or_none()
if local is not None:
return local.prompt_template
Creates a session, runs the setup LLM with directory exploration, cloud_result = await db.execute(
and returns the ``journey_reply`` payload. select(CloudAgentConfig).where(
CloudAgentConfig.id == agent_id,
CloudAgentConfig.user_id == user_id,
)
)
cloud = cloud_result.scalar_one_or_none()
return cloud.prompt_template if cloud is not None else None
# ── Routes ────────────────────────────────────────────────────────────────
@router.post("/start", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
async def start_journey(
body: JourneyStartRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> JourneyResponse:
"""Start a new Chatbot Journey session.
If ``agent_id`` is provided the session is pre-seeded with the existing
agent's ``prompt_template`` so the user can refine it.
""" """
agent_type = frame.get("agent_type", "local") # Load existing template (may be None).
directory = frame.get("directory", "") existing_template: str | None = None
data_types = frame.get("data_types", []) if body.agent_id:
existing_template = frame.get("existing_template") existing_template = await _load_existing_template(body.agent_id, current_user.id, db)
# If agent_id was given but not found, proceed without seeding (don't 404 —
# the user may be starting a fresh journey for a not-yet-persisted config).
# Use the session_id provided by the FE so the reply matches the system_prompt = _build_system_prompt(body.agent_type, existing_template)
# listener key; fall back to a generated one if absent. first_question = _first_question(body.agent_type)
session_id = frame.get("session_id") or str(uuid.uuid4())
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_template)
session = JourneySession( session_id = str(uuid.uuid4())
session = _JourneySession(
session_id=session_id, session_id=session_id,
user_id=user_id, user_id=current_user.id,
agent_type=agent_type, agent_type=body.agent_type,
directory=directory, # Seed history with the AI's first question so it stays consistent.
data_types=data_types, history=[{"role": "assistant", "content": first_question}],
system_prompt=system_prompt,
langfuse_prompt=langfuse_prompt,
) )
# Store the system prompt inside the session for reuse in /message.
# The LLM will explore the directory using FILESYSTEM_TOOLS via the session.__dict__["_system_prompt"] = system_prompt # type: ignore[index]
# ws_context executor (already set by the WS handler before calling us).
# Seed with an initial user message — some providers (e.g. GitHub Copilot)
# require at least one user/input message to be present.
seed_history: list[dict[str, Any]] = [
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
]
ai_reply = await _call_llm_with_tools(
system_prompt=system_prompt,
history=seed_history,
tools=list(FILESYSTEM_TOOLS),
user_id=user_id,
session_id=session_id,
langfuse_prompt=langfuse_prompt,
)
session.history.extend(seed_history)
session.history.append({"role": "assistant", "content": ai_reply})
_sessions[session_id] = session _sessions[session_id] = session
logger.info( logger.info("Journey session %s started for user %s (agent_type=%s)", session_id, current_user.id, body.agent_type)
"agent_setup: journey session %s started for user %s (directory=%s)", return JourneyResponse(session_id=session_id, message=first_question, done=False)
session_id,
user_id,
directory,
)
# Check if the LLM produced the template on the first turn (unlikely but possible).
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
or "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
}
async def handle_journey_message( @router.post("/message", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
user_id: str, async def send_journey_message(
frame: dict[str, Any], body: JourneyMessageRequest,
) -> dict[str, Any]: current_user: UserProfile = Depends(get_current_user),
"""Handle a ``journey_message`` WS frame. db: AsyncSession = Depends(get_session),
) -> JourneyResponse:
"""Send a message in an existing Chatbot Journey session.
Appends the user message, calls the LLM, and returns the The server appends the user's message to the conversation history,
``journey_reply`` payload. calls the LLM, and appends the AI reply. When the LLM wraps up with a
``prompt_template`` block the response includes ``done=True`` and the
extracted template.
""" """
session_id = frame.get("session_id", "") session = _get_session(body.session_id, current_user.id)
message = frame.get("message", "") system_prompt: str = session.__dict__.get("_system_prompt", _build_system_prompt(session.agent_type, None)) # type: ignore[assignment]
session = get_journey_session(session_id, user_id) # Append user turn to history.
if session is None: session.history.append({"role": "user", "content": body.message})
return {
"type": "journey_reply",
"session_id": session_id,
"message": "Journey session not found or expired. Please start a new setup.",
"done": True,
"prompt_template": None,
}
# Append user turn. # Call the LLM with the full conversation so far.
session.history.append({"role": "user", "content": message}) ai_reply = await _call_llm(system_prompt, session.history)
# Call the LLM with tools.
ai_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
user_id=session.user_id,
session_id=session_id,
langfuse_prompt=session.langfuse_prompt,
)
# Append AI turn.
session.history.append({"role": "assistant", "content": ai_reply}) session.history.append({"role": "assistant", "content": ai_reply})
# Check if the LLM produced the final template. # Check if the LLM produced the final template.
prompt_template = _extract_template(ai_reply) prompt_template = _extract_template(ai_reply)
done = prompt_template is not None done = prompt_template is not None
# If the LLM didn't produce a template, nudge it once it has asked enough # Strip the sentinel markers from the message shown to the user.
# questions (>= _MIN_TURNS_BEFORE_NUDGE) or hits the hard safety cap.
if not done:
turns = sum(1 for t in session.history if t["role"] == "user")
if turns >= _MAX_TURNS:
nudge_content = (
"[System: You have enough information. Please generate the final "
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
)
session.history.append({"role": "user", "content": nudge_content})
nudge_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
user_id=session.user_id,
session_id=session_id,
langfuse_prompt=session.langfuse_prompt,
)
session.history.append({"role": "assistant", "content": nudge_reply})
prompt_template = _extract_template(nudge_reply)
if prompt_template is not None:
done = True
ai_reply = nudge_reply
display_message = ai_reply display_message = ai_reply
if done: if done:
display_message = ( display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip() ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
if _TEMPLATE_START in ai_reply or "Here is your agent configuration. You can save it or continue refining."
else "Here is your agent configuration. You can save it or continue refining."
) )
_sessions.pop(session_id, None)
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
return { if done:
"type": "journey_reply", logger.info("Journey session %s completed for user %s", body.session_id, current_user.id)
"session_id": session_id, # Clean up the session immediately on completion.
"message": display_message, _sessions.pop(body.session_id, None)
"done": done, else:
"prompt_template": prompt_template, # Nudge the LLM to wrap up after max turns.
} turns = sum(1 for t in session.history if t["role"] == "user")
if turns >= _MAX_TURNS:
# Add a system-level nudge as a hidden user message.
session.history.append({
"role": "user",
"content": (
"[System: You have enough information. Please generate the final "
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
),
})
return JourneyResponse(
session_id=body.session_id,
message=display_message,
done=done,
prompt_template=prompt_template,
)

View File

@@ -1,36 +1,45 @@
"""Agent routes. """Agent CRUD routes: local directory agents and cloud connector agents.
Backend responsibilities are intentionally minimal: Endpoints:
GET /agents/catalog — static catalog for UI display GET /agents/catalog — hardcoded agent type catalog
POST /agents/can-create — billing eligibility check GET /agents/local — list user's local agent configs
POST /agents/triggertrigger a local agent run POST /agents/local create local agent (tier-gated)
PUT /agents/local/{agent_id} — partial update (ownership check)
Agent configuration is owned by the Electron app and is not persisted DELETE /agents/local/{agent_id} — delete + cascade run logs
in backend agent-config tables. GET /agents/cloud — list user's cloud agent configs
POST /agents/cloud — create cloud agent (tier-gated)
PUT /agents/cloud/{agent_id} — partial update (ownership check)
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
GET /agents/runs — paginated run logs (agent_id, page, limit)
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import uuid from datetime import datetime
from datetime import datetime, timedelta, timezone from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select from pydantic import BaseModel
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user from app.api.deps import get_current_user
from app.billing.tier_manager import FEATURES from app.billing.tier_manager import FEATURES
from app.core.agent_runner import is_agent_running, run_local_agent from app.core.agent_runner import run_cloud_agent, run_local_agent
from app.core.device_manager import device_manager from app.core.device_manager import device_manager
from app.db import get_session from app.db import get_session
from app.models import AgentRunLog, LocalAgentConfig from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
from app.schemas import ( from app.schemas import (
AgentCatalogItem, AgentCatalogItem,
AgentCreationCheckRequest,
AgentCreationCheckResponse,
AgentRunLogResponse, AgentRunLogResponse,
AgentTriggerRequest, CloudAgentConfigCreate,
CloudAgentConfigResponse,
CloudAgentConfigUpdate,
LocalAgentConfigCreate,
LocalAgentConfigResponse,
LocalAgentConfigUpdate,
UserProfile, UserProfile,
) )
@@ -47,21 +56,39 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]: # ── Model → schema converters ─────────────────────────────────────────
normalize = {
"task": "tasks", "tasks": "tasks", def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
"note": "notes", "notes": "notes", return LocalAgentConfigResponse(
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines", id=a.id,
"project": "projects", "projects": "projects", name=a.name,
} device_id=a.device_id,
seen: set[str] = set() directory_paths=a.directory_paths,
result: list[str] = [] data_types=a.data_types,
for v in values: prompt_template=a.prompt_template,
mapped = normalize.get(v) file_extensions=a.file_extensions,
if mapped and mapped not in seen: schedule_cron=a.schedule_cron,
seen.add(mapped) enabled=a.enabled,
result.append(mapped) last_run_at=_dt_ms_opt(a.last_run_at),
return result created_at=_dt_ms(a.created_at),
updated_at=_dt_ms(a.updated_at),
)
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
return CloudAgentConfigResponse(
id=a.id,
provider=a.provider, # type: ignore[arg-type]
name=a.name,
data_types=a.data_types,
prompt_template=a.prompt_template,
schedule_cron=a.schedule_cron,
filter_config=a.filter_config,
enabled=a.enabled,
last_run_at=_dt_ms_opt(a.last_run_at),
created_at=_dt_ms(a.created_at),
updated_at=_dt_ms(a.updated_at),
)
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse: def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
@@ -78,42 +105,77 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
) )
def _enforce_agent_limit(tier: str, current_count: int) -> int: # ── Ownership-checked lookups ─────────────────────────────────────────
async def _get_local_agent_for_user(
agent_id: str, user_id: str, db: AsyncSession
) -> LocalAgentConfig:
result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.id == agent_id,
LocalAgentConfig.user_id == user_id,
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
return record
async def _get_cloud_agent_for_user(
agent_id: str, user_id: str, db: AsyncSession
) -> CloudAgentConfig:
result = await db.execute(
select(CloudAgentConfig).where(
CloudAgentConfig.id == agent_id,
CloudAgentConfig.user_id == user_id,
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
return record
# ── Tier limit helper ─────────────────────────────────────────────────
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
"""Return combined enabled local + cloud agent count for the user."""
local_count = (
await db.execute(
select(func.count(LocalAgentConfig.id)).where(
LocalAgentConfig.user_id == user_id,
LocalAgentConfig.enabled == True, # noqa: E712
)
)
).scalar_one()
cloud_count = (
await db.execute(
select(func.count(CloudAgentConfig.id)).where(
CloudAgentConfig.user_id == user_id,
CloudAgentConfig.enabled == True, # noqa: E712
)
)
).scalar_one()
return local_count + cloud_count
def _enforce_agent_limit(tier: str, current_count: int) -> None:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"] limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit: if limit != -1 and current_count >= limit:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.", detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
) )
return limit
async def _enforce_run_frequency( # ── Local page schema (used by runs endpoint) ─────────────────────────
tier: str,
user_id: str,
db: AsyncSession,
) -> None:
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return # unlimited
today_start = datetime.now(timezone.utc).replace( class _RunsPage(BaseModel):
hour=0, minute=0, second=0, microsecond=0 total: int
) page: int
result = await db.execute( limit: int
select(func.count(AgentRunLog.id)).where( items: list[AgentRunLogResponse]
AgentRunLog.user_id == user_id,
AgentRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
)
# ── Catalog ─────────────────────────────────────────────────────────── # ── Catalog ───────────────────────────────────────────────────────────
@@ -147,61 +209,229 @@ async def get_agent_catalog(
] ]
@router.post("/can-create", response_model=AgentCreationCheckResponse) # ── Local agent CRUD ──────────────────────────────────────────────────
async def can_create_agent(
body: AgentCreationCheckRequest, @router.get("/local", response_model=list[LocalAgentConfigResponse])
async def list_local_agents(
current_user: UserProfile = Depends(get_current_user), current_user: UserProfile = Depends(get_current_user),
) -> AgentCreationCheckResponse: db: AsyncSession = Depends(get_session),
"""Check if the user can create one more agent based on billing tier. ) -> list[LocalAgentConfigResponse]:
"""List all local directory agent configs owned by the authenticated user."""
Since configuration is client-owned, the Electron app sends its current result = await db.execute(
active agent count and the backend applies tier limits. select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
"""
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or body.active_agents < limit
return AgentCreationCheckResponse(
allowed=allowed,
tier=current_user.tier,
active_agents=body.active_agents,
limit=limit,
) )
return [_to_local_response(a) for a in result.scalars().all()]
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED) @router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
async def create_local_agent(
body: LocalAgentConfigCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> LocalAgentConfigResponse:
"""Create a new local directory agent config.
The combined count of enabled local and cloud agents for the user is
checked against the ``batch_active`` limit for their billing tier.
"""
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
agent = LocalAgentConfig(
user_id=current_user.id,
name=body.name,
device_id=body.device_id,
directory_paths=body.directory_paths,
data_types=body.data_types,
prompt_template=body.prompt_template,
file_extensions=body.file_extensions,
schedule_cron=body.schedule_cron,
)
db.add(agent)
await db.commit()
await db.refresh(agent)
return _to_local_response(agent)
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
async def update_local_agent(
agent_id: str,
body: LocalAgentConfigUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> LocalAgentConfigResponse:
"""Partially update a local agent config. Only provided fields are changed."""
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
for field, value in body.model_dump(exclude_unset=True).items():
setattr(agent, field, value)
await db.commit()
await db.refresh(agent)
return _to_local_response(agent)
@router.delete("/local/{agent_id}", response_model=dict)
async def delete_local_agent(
agent_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a local agent config. Associated run logs are cascade-deleted."""
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
await db.delete(agent)
await db.commit()
return {"ok": True}
# ── Cloud agent CRUD ──────────────────────────────────────────────────
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
async def list_cloud_agents(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[CloudAgentConfigResponse]:
"""List all cloud connector agent configs owned by the authenticated user."""
result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
)
return [_to_cloud_response(a) for a in result.scalars().all()]
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
async def create_cloud_agent(
body: CloudAgentConfigCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> CloudAgentConfigResponse:
"""Create a new cloud connector agent config.
The combined count of enabled local and cloud agents for the user is
checked against the ``batch_active`` limit for their billing tier.
"""
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
agent = CloudAgentConfig(
user_id=current_user.id,
provider=body.provider,
name=body.name,
data_types=body.data_types,
prompt_template=body.prompt_template,
oauth_token_encrypted=body.oauth_token_encrypted,
schedule_cron=body.schedule_cron,
filter_config=body.filter_config,
)
db.add(agent)
await db.commit()
await db.refresh(agent)
return _to_cloud_response(agent)
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
async def update_cloud_agent(
agent_id: str,
body: CloudAgentConfigUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> CloudAgentConfigResponse:
"""Partially update a cloud agent config. Only provided fields are changed."""
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
for field, value in body.model_dump(exclude_unset=True).items():
setattr(agent, field, value)
await db.commit()
await db.refresh(agent)
return _to_cloud_response(agent)
@router.delete("/cloud/{agent_id}", response_model=dict)
async def delete_cloud_agent(
agent_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
await db.delete(agent)
await db.commit()
return {"ok": True}
# ── Run logs ──────────────────────────────────────────────────────────
@router.get("/runs", response_model=_RunsPage)
async def list_run_logs(
agent_id: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
limit: int = Query(default=20, ge=1, le=100),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _RunsPage:
"""Return paginated run logs for the authenticated user.
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
"""
base_filter = [AgentRunLog.user_id == current_user.id]
if agent_id:
base_filter.append(AgentRunLog.agent_id == agent_id)
total = (
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
).scalar_one()
result = await db.execute(
select(AgentRunLog)
.where(*base_filter)
.order_by(AgentRunLog.started_at.desc())
.offset((page - 1) * limit)
.limit(limit)
)
items = [_to_run_log_response(log) for log in result.scalars().all()]
return _RunsPage(total=total, page=page, limit=limit, items=items)
# ── Manual trigger stub ───────────────────────────────────────────────
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run( async def trigger_agent_run(
body: AgentTriggerRequest, agent_id: str,
current_user: UserProfile = Depends(get_current_user), current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session), db: AsyncSession = Depends(get_session),
) -> AgentRunLogResponse: ) -> AgentRunLogResponse:
"""Trigger a local agent run using client-provided configuration.""" """Manually trigger an agent run.
_enforce_agent_limit(current_user.tier, body.active_agents)
await _enforce_run_frequency(current_user.tier, current_user.id, db)
config = LocalAgentConfig( Looks up the agent config (local or cloud) by ID with ownership check,
id=str(uuid.uuid4()), creates a run log entry with ``status="running"``, and returns it.
user_id=current_user.id,
device_id=body.device_id, Actual dispatch to the agent runner is wired in Step 3.4 once
name="Local Directory Monitor", ``DeviceConnectionManager`` and ``agent_runner`` are available.
directory_paths=[body.directory], """
data_types=_to_data_types(body.what_to_extract), # Determine agent type by trying local first, then cloud.
prompt_template=body.custom_agent_prompt, # Keep the full config object so we can pass it to the agent runner.
file_extensions=[], local_config: LocalAgentConfig | None = None
schedule_cron=body.batch_interval, cloud_config: CloudAgentConfig | None = None
enabled=True,
local_result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.id == agent_id,
LocalAgentConfig.user_id == current_user.id,
) )
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
stable_agent_id = body.agent_id or config.id
if is_agent_running(stable_agent_id):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Agent is already running. Only one run per agent is allowed at a time.",
) )
local_config = local_result.scalar_one_or_none()
if local_config is not None:
agent_type = "local"
else:
cloud_result = await db.execute(
select(CloudAgentConfig).where(
CloudAgentConfig.id == agent_id,
CloudAgentConfig.user_id == current_user.id,
)
)
cloud_config = cloud_result.scalar_one_or_none()
if cloud_config is not None:
agent_type = "cloud"
else:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
run_log = AgentRunLog( run_log = AgentRunLog(
agent_id=stable_agent_id, agent_id=agent_id,
agent_type="local", agent_type=agent_type,
user_id=current_user.id, user_id=current_user.id,
status="running", status="running",
) )
@@ -209,14 +439,14 @@ async def trigger_agent_run(
await db.commit() await db.commit()
await db.refresh(run_log) await db.refresh(run_log)
run_context = { # Dispatch the run as a background task — returns 202 immediately.
"type": "agent_batch", if agent_type == "local" and local_config is not None:
"run_id": run_log.id,
"agent_id": stable_agent_id,
}
asyncio.create_task( asyncio.create_task(
run_local_agent(current_user.id, config, run_log, device_manager, run_context) run_local_agent(current_user.id, local_config, run_log, device_manager)
)
elif agent_type == "cloud" and cloud_config is not None:
asyncio.create_task(
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
) )
return _to_run_log_response(run_log) return _to_run_log_response(run_log)

171
app/api/routes/backup.py Normal file
View File

@@ -0,0 +1,171 @@
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
PostgreSQL ``backup_metadata`` table.
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
treating "history" as a ``{backup_id}`` path parameter.
"""
from __future__ import annotations
import uuid
from email.utils import parsedate_to_datetime
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import BackupMetadata as BackupMetadataModel
from app.schemas import BackupMetadata, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/backup", tags=["backup"])
_blob_store = BlobStore()
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total backup bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
BackupMetadataModel.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_backup_quota(
user: UserProfile, size_bytes: int, db: AsyncSession
) -> None:
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
current = await _current_backup_bytes(user.id, db)
tier_manager.enforce_backup_quota(
user.tier, current_bytes=current, additional_bytes=size_bytes
)
@router.put("")
async def upload_backup(
request: Request,
x_backup_version: int = Header(..., alias="X-Backup-Version"),
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Upload an E2E-encrypted backup blob.
Metadata is passed via custom headers; the raw body is the encrypted blob.
"""
blob = await request.body()
reject_if_tampered(blob, x_backup_checksum)
await _check_backup_quota(current_user, len(blob), db)
s3_key = await _blob_store.upload(
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
)
row = BackupMetadataModel(
id=str(uuid.uuid4()),
user_id=current_user.id,
s3_key=s3_key,
version=x_backup_version,
timestamp=x_backup_timestamp,
checksum=x_backup_checksum,
size_bytes=len(blob),
)
db.add(row)
await db.commit()
return {"ok": True}
@router.get("/history", response_model=list[BackupMetadata])
async def backup_history(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[BackupMetadata]:
"""Return backup metadata records for the authenticated user (no blob bytes)."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
)
rows = result.scalars().all()
return [
BackupMetadata(
version=r.version,
timestamp=r.timestamp,
checksum=r.checksum,
chunk_count=1,
)
for r in rows
]
@router.get("")
async def download_backup(
request: Request,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
.limit(1)
)
latest = result.scalar_one_or_none()
if latest is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
ims_header = request.headers.get("If-Modified-Since")
if ims_header:
try:
ims_dt = parsedate_to_datetime(ims_header)
ims_ms = int(ims_dt.timestamp() * 1000)
if latest.timestamp <= ims_ms:
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
except Exception:
pass # malformed header — ignore and serve the blob
blob = await _blob_store.download(current_user.id, latest.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={
"X-Backup-Version": str(latest.version),
"X-Backup-Timestamp": str(latest.timestamp),
"X-Checksum": latest.checksum,
},
)
@router.delete("/{backup_id}", response_model=dict)
async def delete_backup(
backup_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a specific backup by ID."""
result = await db.execute(
select(BackupMetadataModel).where(
BackupMetadataModel.id == backup_id,
BackupMetadataModel.user_id == current_user.id,
)
)
target = result.scalar_one_or_none()
if target is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
await _blob_store.delete(current_user.id, target.s3_key)
await db.delete(target)
await db.commit()
return {"ok": True}

View File

@@ -1,4 +1,4 @@
"""Chat routes: POST /chat (REST fallback) and POST /chat/embed (text → vector). """Chat routes: POST /chat (REST fallback).
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device). WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
""" """
@@ -7,53 +7,36 @@ from __future__ import annotations
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import BaseModel
from app.api.deps import get_current_user from app.api.deps import get_current_user
from app.core.deep_agent import run_home from app.core.deep_agent import run_home
from app.core.llm import embed from app.core.memory_middleware import MemoryMiddleware
from app.schemas import ChatRequest, UserProfile from app.db import async_session
from app.schemas import ChatRequest, ChatResponse, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"]) router = APIRouter(prefix="/chat", tags=["chat"])
# ── Embed helpers ─────────────────────────────────────────────────────────
class _EmbedRequest(BaseModel):
text: str
class _EmbedResponse(BaseModel):
vector: list[float]
# ── Endpoints ─────────────────────────────────────────────────────────────
@router.post("") @router.post("")
async def chat( async def chat(
body: ChatRequest, body: ChatRequest,
current_user: UserProfile = Depends(get_current_user), current_user: UserProfile = Depends(get_current_user),
) -> JSONResponse: ) -> JSONResponse:
"""REST fallback for home chat when websocket streaming is unavailable.""" """Route a chat message through the Home deep agent (non-streaming)."""
response = await run_home( async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(current_user.id, body.message)
context = {
**body.context.model_dump(),
**memory_context,
}
response_text = await run_home(
user_id=current_user.id, user_id=current_user.id,
message=body.message, message=body.message,
context=body.context.model_dump(), context=context,
db_session_factory=async_session,
) )
return JSONResponse(content={"response": response}) result = ChatResponse(response=response_text)
return JSONResponse(content=result.model_dump())
@router.post("/embed", response_model=_EmbedResponse)
async def embed_text(
body: _EmbedRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _EmbedResponse:
"""Generate a 1536-dim embedding vector for the given text.
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
Used by Electron (vectordb.ts) for local note search.
"""
vector = await embed(body.text)
return _EmbedResponse(vector=vector)

View File

@@ -15,8 +15,8 @@ Protocol:
Incoming frame dispatch: Incoming frame dispatch:
- ``tool_result`` → resolves a pending tool-call Future. - ``tool_result`` → resolves a pending tool-call Future.
- ``journey_start`` → starts a guided setup journey session. - ``agent_data`` enqueued in the per-run agent data queue.
- ``journey_message`` → continues a journey conversation. - ``agent_complete`` → sends None sentinel to close the queue stream.
- ``pong`` → heartbeat acknowledgement (updates last-seen). - ``pong`` → heartbeat acknowledgement (updates last-seen).
- unknown types → logged, ignored. - unknown types → logged, ignored.
@@ -39,13 +39,12 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt from jose import JWTError, jwt
from sqlalchemy import update from sqlalchemy import update
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
from app.config.settings import settings from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs from app.core.agent_runner import trigger_pending_runs
from app.core.deep_agent import run_floating_stream, run_home_stream
from app.core.device_manager import device_manager from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware from app.core.memory_middleware import MemoryMiddleware
from app.core.output_formatter import StreamFormatter from app.core.deep_agent import run_home_stream, run_floating_stream
from app.core.output_formatter import HomeFormatter, FloatingFormatter
from app.core.ws_context import clear_client_executor, set_client_executor from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session from app.db import async_session
from app.models import AgentRunLog from app.models import AgentRunLog
@@ -148,6 +147,37 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
"device_ws: tool_result missing id from user=%s", user_id "device_ws: tool_result missing id from user=%s", user_id
) )
elif frame_type == WsFrameType.agent_data:
run_id = frame.get("run_id")
if run_id:
try:
queue = device_manager.get_agent_data_queue(user_id, run_id)
await queue.put(frame)
except RuntimeError:
logger.warning(
"device_ws: agent_data for unknown run user=%s run=%s",
user_id,
run_id,
)
else:
logger.warning(
"device_ws: agent_data missing run_id from user=%s", user_id
)
elif frame_type == WsFrameType.agent_complete:
run_id = frame.get("run_id")
if run_id:
try:
queue = device_manager.get_agent_data_queue(user_id, run_id)
# Sentinel: signals the agent data stream is finished.
await queue.put(None)
except RuntimeError:
pass
else:
logger.warning(
"device_ws: agent_complete missing run_id from user=%s", user_id
)
elif frame_type == WsFrameType.home_request: elif frame_type == WsFrameType.home_request:
asyncio.create_task( asyncio.create_task(
_handle_home_request(websocket, user_id, frame) _handle_home_request(websocket, user_id, frame)
@@ -158,16 +188,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_floating_request(websocket, user_id, frame) _handle_floating_request(websocket, user_id, frame)
) )
elif frame_type == WsFrameType.journey_start:
asyncio.create_task(
_handle_journey_start(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_message:
asyncio.create_task(
_handle_journey_message(websocket, user_id, frame)
)
elif frame_type == "pong": elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive. # Heartbeat ack — nothing to do, connection is alive.
pass pass
@@ -180,13 +200,35 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
# ── v3 Chat Handlers ────────────────────────────────────────────────── # ── v3 Chat Handlers ──────────────────────────────────────────────────
_WS_TOOL_CALL_TIMEOUT = 30 # seconds to wait for Electron tool_result
async def _make_ws_executor(websocket: WebSocket, user_id: str): async def _make_ws_executor(websocket: WebSocket, user_id: str):
"""Return a callback that sends tool_call frames and awaits tool_result.""" """Return a callback that sends tool_call frames and awaits tool_result."""
async def _executor(payload: dict) -> dict: async def _executor(payload: dict) -> dict:
payload["type"] = WsFrameType.tool_call payload["type"] = WsFrameType.tool_call
call_id = payload["id"]
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
await websocket.send_text(json.dumps(payload)) await websocket.send_text(json.dumps(payload))
future = device_manager.create_pending_call(user_id, payload["id"]) future = device_manager.create_pending_call(user_id, call_id)
return await future try:
result = await asyncio.wait_for(future, timeout=_WS_TOOL_CALL_TIMEOUT)
except asyncio.TimeoutError:
logger.error(
"ws_executor: timeout waiting for tool_result id=%s action=%s user=%s",
call_id, payload.get("action"), user_id,
)
# Clean up the pending future so it doesn't leak
conn = device_manager._connections.get(user_id)
if conn:
conn.pending_calls.pop(call_id, None)
return {"error": f"Tool call timed out after {_WS_TOOL_CALL_TIMEOUT}s", "rows": []}
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
call_id, type(result).__name__,
list(result.keys()) if isinstance(result, dict) else "N/A")
if result is None:
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
return result
return _executor return _executor
@@ -199,27 +241,14 @@ async def _handle_home_request(
request_id = frame.get("request_id") or str(uuid4()) request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "") message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4()) session_id: str = frame.get("session_id") or str(uuid4())
logger.info(
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
user_id,
request_id,
session_id,
message[:200],
)
# ── Memory: enrich context before LLM call ──────────────────────── # ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db: async with async_session() as db:
memory = MemoryMiddleware(db) memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context( memory_context = await memory.enrich_context(user_id, message)
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = { context: dict = {
"conversation_history": frame.get("conversation_history", []), "conversation_history": frame.get("conversation_history", []),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context, **memory_context,
} }
@@ -227,11 +256,12 @@ async def _handle_home_request(
set_client_executor(executor) set_client_executor(executor)
response_chunks: list[str] = [] response_chunks: list[str] = []
try: try:
event_stream = run_home_stream(user_id, message, context) event_stream = run_home_stream(
formatter = StreamFormatter(request_id=request_id) user_id, message, context, db_session_factory=async_session
)
formatter = HomeFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream): async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json()) await websocket.send_text(ws_frame.model_dump_json())
# Collect text chunks to build the full response for episode storage
if ws_frame.type == "stream_text": # type: ignore[union-attr] if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc: except Exception as exc:
@@ -246,14 +276,7 @@ async def _handle_home_request(
async with async_session() as db: async with async_session() as db:
memory = MemoryMiddleware(db) memory = MemoryMiddleware(db)
await memory.store_episode( await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id user_id, session_id, message, "".join(response_chunks)
)
logger.info(
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
) )
@@ -267,37 +290,23 @@ async def _handle_floating_request(
message: str = frame.get("message", "") message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4()) session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {}) scope: dict = frame.get("scope", {})
logger.info(
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
user_id,
request_id,
session_id,
json.dumps(scope, ensure_ascii=True)[:200],
message[:200],
)
# ── Memory: enrich context before LLM call ──────────────────────── # ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db: async with async_session() as db:
memory = MemoryMiddleware(db) memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context( memory_context = await memory.enrich_context(user_id, message)
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = { context: dict = {"scope": scope, **memory_context}
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id) executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor) set_client_executor(executor)
response_chunks: list[str] = [] response_chunks: list[str] = []
try: try:
event_stream = run_floating_stream(user_id, message, context) event_stream = run_floating_stream(
formatter = StreamFormatter(request_id=request_id) user_id, message, context, scope=scope,
db_session_factory=async_session,
)
formatter = FloatingFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream): async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json()) await websocket.send_text(ws_frame.model_dump_json())
if ws_frame.type == "stream_text": # type: ignore[union-attr] if ws_frame.type == "stream_text": # type: ignore[union-attr]
@@ -314,72 +323,8 @@ async def _handle_floating_request(
async with async_session() as db: async with async_session() as db:
memory = MemoryMiddleware(db) memory = MemoryMiddleware(db)
await memory.store_episode( await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id user_id, session_id, message, "".join(response_chunks)
) )
logger.info(
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
)
# ── v4 Journey Handlers ─────────────────────────────────────────────
async def _handle_journey_start(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a journey_start frame — explores directory and sends first question."""
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
reply = await handle_journey_start(user_id, frame)
await websocket.send_text(json.dumps(reply))
except Exception as exc:
logger.error(
"device_ws: journey_start failed user=%s: %s", user_id, exc
)
await websocket.send_text(json.dumps({
"type": "journey_reply",
"session_id": frame.get("session_id", ""),
"message": f"Failed to start journey: {exc}",
"done": True,
"prompt_template": None,
}))
finally:
clear_client_executor()
async def _handle_journey_message(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a journey_message frame — continues the journey conversation."""
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
reply = await handle_journey_message(user_id, frame)
await websocket.send_text(json.dumps(reply))
except Exception as exc:
session_id = frame.get("session_id", "")
logger.error(
"device_ws: journey_message failed user=%s session=%s: %s",
user_id, session_id, exc,
)
await websocket.send_text(json.dumps({
"type": "journey_reply",
"session_id": session_id,
"message": f"Journey error: {exc}",
"done": True,
"prompt_template": None,
}))
finally:
clear_client_executor()
# ── Heartbeat ───────────────────────────────────────────────────────── # ── Heartbeat ─────────────────────────────────────────────────────────
@@ -415,3 +360,6 @@ async def _mark_runs_disconnected(user_id: str) -> None:
user_id, user_id,
exc, exc,
) )

148
app/api/routes/plugins.py Normal file
View File

@@ -0,0 +1,148 @@
"""Plugins routes: browse and install plugins from the marketplace.
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
"""
from __future__ import annotations
from typing import Any, Literal
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.db import get_session
from app.marketplace.plugin_registry import registry
from app.marketplace.revenue_share import revenue_share
from app.models import PluginInstallation, PluginReview as PluginReviewModel
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
router = APIRouter(prefix="/plugins", tags=["plugins"])
# ── Tier gate ─────────────────────────────────────────────────────────
def _require_plugin_tier(user: UserProfile) -> None:
"""Raise HTTP 403 for users below Power tier."""
if user.tier not in ("power", "team"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Plugin marketplace requires Power tier or above",
)
# ── Local detail schema ────────────────────────────────────────────────
class _PluginDetail(BaseModel):
plugin: PluginManifest
install_count: int
ratings: list[Any]
# ── Routes ────────────────────────────────────────────────────────────
@router.get("", response_model=PluginListResponse)
async def list_plugins(
category: str | None = Query(default=None),
q: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> PluginListResponse:
"""Browse the plugin marketplace. Requires Power tier or above."""
_require_plugin_tier(current_user)
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
@router.get("/{plugin_id}", response_model=_PluginDetail)
async def get_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _PluginDetail:
"""Get full plugin details including install count. Requires Power tier or above."""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Fetch review ratings for this plugin
review_result = await db.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
)
reviews = review_result.scalars().all()
ratings = [
{
"reviewer_id": r.reviewer_id,
"decision": r.decision,
"notes": r.notes,
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
}
for r in reviews
]
return _PluginDetail(
plugin=entry["manifest"],
install_count=entry["install_count"],
ratings=ratings,
)
@router.post("/{plugin_id}/install", response_model=dict)
async def install_plugin(
plugin_id: str,
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
Requires Power tier or above.
"""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Record the installation in plugin_installations
installation = PluginInstallation(
plugin_id=plugin_id,
user_id=current_user.id,
)
db.add(installation)
await db.flush()
await revenue_share.record_install(
db,
plugin_id=plugin_id,
user_id=current_user.id,
amount_cents=entry["manifest"].price_cents,
)
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
return {"ok": True, "download_url": download_url}
@router.delete("/{plugin_id}/install", response_model=dict)
async def uninstall_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Unregister a plugin installation."""
result = await db.execute(
select(PluginInstallation).where(
PluginInstallation.plugin_id == plugin_id,
PluginInstallation.user_id == current_user.id,
)
)
installation = result.scalar_one_or_none()
if installation is not None:
await db.delete(installation)
await db.commit()
await registry.record_uninstall(db, plugin_id)
return {"ok": True}

195
app/api/routes/storage.py Normal file
View File

@@ -0,0 +1,195 @@
"""Storage routes: CRUD for E2E-encrypted cloud records.
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
PostgreSQL ``storage_records`` table.
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import StorageRecord
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/storage", tags=["storage"])
_blob_store = BlobStore()
# ── Local response schemas ─────────────────────────────────────────────
class _CreateResponse(BaseModel):
id: str
created_at: int
class _RecordMeta(BaseModel):
id: str
table: str
checksum: str
created_at: int
updated_at: int
# ── Helpers ────────────────────────────────────────────────────────────
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
StorageRecord.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
current = await _current_usage_bytes(user.id, db)
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
async def _get_record_for_user(
record_id: str, user_id: str, db: AsyncSession
) -> StorageRecord:
"""Look up a record and verify ownership. Returns 404 on mismatch
to prevent user enumeration attacks."""
result = await db.execute(
select(StorageRecord).where(
StorageRecord.id == record_id, StorageRecord.user_id == user_id
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
return record
# ── Routes ─────────────────────────────────────────────────────────────
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
async def create_record(
body: StorageRecordCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _CreateResponse:
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
reject_if_tampered(body.blob, body.checksum)
await _check_quota(current_user, len(body.blob), db)
record_id = str(uuid.uuid4())
s3_key = await _blob_store.upload(
current_user.id, body.table, record_id, body.blob, body.checksum
)
record = StorageRecord(
id=record_id,
user_id=current_user.id,
table_name=body.table,
s3_key=s3_key,
checksum=body.checksum,
size_bytes=len(body.blob),
)
db.add(record)
await db.commit()
await db.refresh(record)
created_at_ms = int(record.created_at.timestamp() * 1000)
return _CreateResponse(id=record_id, created_at=created_at_ms)
@router.get("/records", response_model=list[_RecordMeta])
async def list_records(
table: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
limit: int = Query(default=50, ge=1, le=200),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[_RecordMeta]:
"""List record metadata for the authenticated user. Blob bytes are never returned."""
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
if table is not None:
query = query.where(StorageRecord.table_name == table)
query = query.offset((page - 1) * limit).limit(limit)
result = await db.execute(query)
rows = result.scalars().all()
return [
_RecordMeta(
id=r.id,
table=r.table_name,
checksum=r.checksum,
created_at=int(r.created_at.timestamp() * 1000),
updated_at=int(r.updated_at.timestamp() * 1000),
)
for r in rows
]
@router.get("/records/{record_id}")
async def download_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
record = await _get_record_for_user(record_id, current_user.id, db)
blob = await _blob_store.download(current_user.id, record.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={"X-Checksum": record.checksum},
)
@router.put("/records/{record_id}", response_model=dict)
async def update_record(
record_id: str,
body: StorageRecordUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Replace the blob for an existing record. Verifies checksum before storing."""
record = await _get_record_for_user(record_id, current_user.id, db)
reject_if_tampered(body.blob, body.checksum)
delta = len(body.blob) - record.size_bytes
if delta > 0:
await _check_quota(current_user, delta, db)
s3_key = await _blob_store.upload(
current_user.id, record.table_name, record_id, body.blob, body.checksum
)
record.s3_key = s3_key
record.checksum = body.checksum
record.size_bytes = len(body.blob)
await db.commit()
return {"ok": True}
@router.delete("/records/{record_id}", response_model=dict)
async def delete_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a record and its S3 blob."""
record = await _get_record_for_user(record_id, current_user.id, db)
await _blob_store.delete(current_user.id, record.s3_key)
await db.delete(record)
await db.commit()
return {"ok": True}

79
app/api/routes/vectors.py Normal file
View File

@@ -0,0 +1,79 @@
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
from __future__ import annotations
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.core.llm import embed
from app.schemas import (
UserProfile,
VectorSearchRequest,
VectorSearchResponse,
VectorUpsertRequest,
)
from app.storage.encryption import reject_if_tampered
from app.storage.vector_store import VectorStore
router = APIRouter(prefix="/storage", tags=["vectors"])
_vector_store = VectorStore()
class _VectorDeleteRequest(BaseModel):
ids: list[str]
class _EmbedRequest(BaseModel):
text: str
class _EmbedResponse(BaseModel):
vector: list[float]
@router.post("/vectors/upsert", response_model=dict)
async def upsert_vectors(
body: VectorUpsertRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, int]:
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
for item in body.vectors:
reject_if_tampered(item.blob, item.checksum)
await _vector_store.upsert(current_user.id, body.vectors)
return {"upserted": len(body.vectors)}
@router.post("/vectors/search", response_model=VectorSearchResponse)
async def search_vectors(
body: VectorSearchRequest,
current_user: UserProfile = Depends(get_current_user),
) -> VectorSearchResponse:
"""Search the user-scoped vector namespace with an encrypted query blob."""
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
return VectorSearchResponse(results=results)
@router.delete("/vectors", response_model=dict)
async def delete_vectors(
body: _VectorDeleteRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
"""Delete vectors by ID, scoped to the authenticated user."""
await _vector_store.delete(current_user.id, body.ids)
return {"ok": True}
@router.post("/vectors/embed", response_model=_EmbedResponse)
async def embed_text(
body: _EmbedRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _EmbedResponse:
"""Generate a 1536-dim embedding vector for the given text.
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
"""
vector = await embed(body.text)
return _EmbedResponse(vector=vector)

View File

@@ -21,33 +21,41 @@ FEATURES: dict[str, dict[str, Any]] = {
"free": { "free": {
"agents": 3, "agents": 3,
"batch_active": 2, "batch_active": 2,
"batch_runs_per_day": 5, "cloud_storage_gb": 0,
"backup_gb": 0,
"providers": 1, "providers": 1,
"batch_builder": False, "batch_builder": False,
"plugin_marketplace": False,
"sso": False, "sso": False,
}, },
"pro": { "pro": {
"agents": -1, # unlimited "agents": -1, # unlimited
"batch_active": 10, "batch_active": 10,
"batch_runs_per_day": 50, "cloud_storage_gb": 5,
"backup_gb": 5,
"providers": -1, "providers": -1,
"batch_builder": False, "batch_builder": False,
"plugin_marketplace": False,
"sso": False, "sso": False,
}, },
"power": { "power": {
"agents": -1, "agents": -1,
"batch_active": -1, # unlimited "batch_active": -1, # unlimited
"batch_runs_per_day": -1, # unlimited "cloud_storage_gb": 25,
"backup_gb": 25,
"providers": -1, "providers": -1,
"batch_builder": True, "batch_builder": True,
"plugin_marketplace": True,
"sso": False, "sso": False,
}, },
"team": { "team": {
"agents": -1, "agents": -1,
"batch_active": -1, "batch_active": -1,
"batch_runs_per_day": -1, # unlimited "cloud_storage_gb": -1, # unlimited
"backup_gb": -1, # unlimited
"providers": -1, "providers": -1,
"batch_builder": True, "batch_builder": True,
"plugin_marketplace": True,
"sso": True, "sso": True,
}, },
} }
@@ -69,18 +77,16 @@ class TierManager:
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier: async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
"""Return the current billing tier for ``user_id`` from the DB. """Return the current billing tier for ``user_id`` from the DB.
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod Falls back to ``'free'`` when no subscription row exists.
when no subscription row exists.
""" """
from app.models import Subscription # noqa: PLC0415 from app.models import Subscription # noqa: PLC0415
from app.config.settings import settings # noqa: PLC0415
result = await db.execute( result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id) select(Subscription.tier).where(Subscription.user_id == user_id)
) )
tier: str | None = result.scalar_one_or_none() tier: str | None = result.scalar_one_or_none()
if tier is None or tier not in FEATURES: if tier is None or tier not in FEATURES:
return "power" if settings.ENV == "dev" else "free" return "free"
return tier # type: ignore[return-value] return tier # type: ignore[return-value]
# ── Feature access ─────────────────────────────────────────────────── # ── Feature access ───────────────────────────────────────────────────
@@ -113,6 +119,71 @@ class TierManager:
"""Return the requests-per-minute limit for ``tier``.""" """Return the requests-per-minute limit for ``tier``."""
return RATE_LIMITS.get(tier, RATE_LIMITS["free"]) return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
# ── Storage quota ────────────────────────────────────────────────────
def enforce_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
``tier`` is the caller's current tier (from ``current_user.tier``).
``current_bytes`` is the total bytes already stored (queried by caller).
"""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Cloud storage is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Storage quota exceeded for tier '{tier}'",
)
def enforce_backup_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
limit_gb: int = FEATURES[tier]["backup_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup quota exceeded for tier '{tier}'",
)
def check_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> bool:
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
return False
if limit_gb == -1:
return True
limit_bytes = limit_gb * 1024 ** 3
return current_bytes + additional_bytes <= limit_bytes
# Module-level singleton shared across the app. # Module-level singleton shared across the app.
tier_manager = TierManager() tier_manager = TierManager()

View File

@@ -12,6 +12,17 @@ class Settings(BaseSettings):
STRIPE_SECRET_KEY: str = "" STRIPE_SECRET_KEY: str = ""
STRIPE_WEBHOOK_SECRET: str = "" STRIPE_WEBHOOK_SECRET: str = ""
S3_BUCKET: str = ""
S3_REGION: str = "us-east-1"
S3_ENDPOINT_URL: str = ""
AWS_ACCESS_KEY_ID: str = ""
AWS_SECRET_ACCESS_KEY: str = ""
PINECONE_API_KEY: str = ""
PINECONE_INDEX: str = "adiuva"
QDRANT_URL: str = ""
QDRANT_API_KEY: str = ""
OPENAI_API_KEY: str = "" OPENAI_API_KEY: str = ""
ANTHROPIC_API_KEY: str = "" ANTHROPIC_API_KEY: str = ""
GOOGLE_API_KEY: str = "" GOOGLE_API_KEY: str = ""
@@ -41,10 +52,6 @@ class Settings(BaseSettings):
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"] CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
LANGFUSE_SECRET_KEY: str = ""
LANGFUSE_PUBLIC_KEY: str = ""
LANGFUSE_HOST: str = "https://cloud.langfuse.com"
ENV: Literal["dev", "prod"] = "dev" ENV: Literal["dev", "prod"] = "dev"
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

View File

@@ -1,30 +0,0 @@
"""Minimal agent base types retained for compatibility with batch runners."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class BaseAgent(ABC):
"""Common base for non-chat agents still using the old base contract."""
def __init__(
self,
user_id: str = "",
shared_memory: dict[str, Any] | None = None,
vector_store_context: list[str] | None = None,
) -> None:
self.user_id = user_id
self.shared_memory: dict[str, Any] = shared_memory or {}
self.vector_store_context: list[str] = vector_store_context or []
@abstractmethod
def get_name(self) -> str: ...
@abstractmethod
def get_description(self) -> str: ...
@property
def skills(self) -> list[str]:
return []

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -3,15 +3,20 @@
Maintains in-memory state for all active Electron → backend WebSocket Maintains in-memory state for all active Electron → backend WebSocket
connections. One connection per user (latest replaces previous). connections. One connection per user (latest replaces previous).
The manager handles the **tool-call round-trip** pattern: The manager participates in two interaction patterns:
- Backend sends ``tool_call`` frame → Electron executes the action →
returns ``tool_result`` frame. 1. **Tool-call round-trip** (bidirectional CRUD):
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
``tool_result`` frame.
- ``create_pending_call`` registers a Future keyed by ``call_id``. - ``create_pending_call`` registers a Future keyed by ``call_id``.
- ``resolve_pending_call`` fulfils the Future; callers awaiting it - ``resolve_pending_call`` fulfils the Future; callers awaiting it
receive the result dict from Electron. receive the result dict from Electron.
This pattern is used by all tools (CRUD, file-system, etc.) via 2. **Agent-data streaming** (local directory agent runs):
``execute_on_client()`` in ``ws_context.py``. - Backend sends ``agent_run`` frame → Electron reads files and sends
back a stream of ``agent_data`` frames followed by ``agent_complete``.
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
a specific ``run_id`` so the agent runner can iterate frames.
The ``device_manager`` module-level singleton is imported by both the The ``device_manager`` module-level singleton is imported by both the
device WS route and the agent runner. device WS route and the agent runner.
@@ -37,6 +42,8 @@ class DeviceConnection:
device_id: str device_id: str
# Futures indexed by tool_call id — resolved when tool_result arrives. # Futures indexed by tool_call id — resolved when tool_result arrives.
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict) pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
# Per-run queues for agent_data / agent_complete frames.
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
class DeviceConnectionManager: class DeviceConnectionManager:
@@ -146,6 +153,31 @@ class DeviceConnectionManager:
if fut is not None and not fut.done(): if fut is not None and not fut.done():
fut.set_result(result) fut.set_result(result)
# ── Agent-data queue ──────────────────────────────────────────────
def get_agent_data_queue(
self, user_id: str, run_id: str
) -> asyncio.Queue[dict | None]:
"""Return (creating if absent) the queue for *run_id* agent frames.
The agent runner reads from this queue. The device WS handler writes
to it. ``None`` is the sentinel that signals the stream is finished.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"get_agent_data_queue: user {user_id!r} is not connected"
)
if run_id not in conn.agent_data_queues:
conn.agent_data_queues[run_id] = asyncio.Queue()
return conn.agent_data_queues[run_id]
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
"""Remove the queue for *run_id* once a run has completed."""
conn = self._connections.get(user_id)
if conn:
conn.agent_data_queues.pop(run_id, None)
# Module-level singleton — import this everywhere. # Module-level singleton — import this everywhere.
device_manager = DeviceConnectionManager() device_manager = DeviceConnectionManager()

View File

@@ -1,114 +0,0 @@
"""Langfuse observability — singleton client and prompt helpers.
If LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY are not set,
all helpers are no-ops so the app works without Langfuse configured.
Usage
-----
Tracing::
from app.core.langfuse_client import get_langfuse
lf = get_langfuse()
if lf:
with lf.start_as_current_observation(as_type="span", name="my-agent") as span:
span.update(input=user_message)
# ... do work ...
span.update(output=result)
lf.flush()
Prompt management::
from app.core.langfuse_client import get_prompt_or_fallback
text, prompt_obj = get_prompt_or_fallback("home_system", FALLBACK_PROMPT)
# Use text as the system prompt; pass prompt_obj to generations for linking.
Linking a prompt to a generation::
with lf.start_as_current_observation(
as_type="generation",
name="llm-call",
model="gpt-4o",
prompt=prompt_obj, # links generation → prompt version in the UI
input=messages,
) as gen:
response = await llm.ainvoke(messages)
gen.update(output=response.content, usage=_usage(response))
"""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
_client: Any = None
_initialized: bool = False
def get_langfuse() -> Any | None:
"""Return the Langfuse singleton, or ``None`` when not configured."""
global _client, _initialized
if _initialized:
return _client
_initialized = True
from app.config.settings import settings # local import to avoid circular deps
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
logger.debug("langfuse: not configured — observability disabled")
return None
try:
from langfuse import Langfuse
_client = Langfuse(
secret_key=settings.LANGFUSE_SECRET_KEY,
public_key=settings.LANGFUSE_PUBLIC_KEY,
host=settings.LANGFUSE_HOST,
)
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_HOST)
except Exception as exc:
logger.warning("langfuse: failed to initialize: %s", exc)
_client = None
return _client
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
Returns ``(prompt_text, prompt_obj_or_None)``.
* ``prompt_text`` — the raw template string (variables not yet substituted).
Callers perform variable substitution with Python's ``.format()``.
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
unavailable / the fetch failed. Pass this to generation observations so
Langfuse links the generation to the exact prompt version in the UI.
"""
lf = get_langfuse()
if lf is None:
return fallback, None
try:
prompt = lf.get_prompt(name, label="production", fallback=fallback)
# For text-type prompts .prompt holds the raw template string.
raw = prompt.prompt if hasattr(prompt, "prompt") and isinstance(prompt.prompt, str) else fallback
return raw, prompt
except Exception as exc:
logger.warning("langfuse: get_prompt %r failed: %s — using fallback", name, exc)
return fallback, None
def extract_usage(response: Any) -> dict[str, int]:
"""Extract token usage from a LangChain AI message into Langfuse format."""
meta = getattr(response, "usage_metadata", None)
if not meta:
return {}
return {
"input": int(meta.get("input_tokens", 0)),
"output": int(meta.get("output_tokens", 0)),
"total": int(meta.get("total_tokens", 0)),
}

View File

@@ -1,6 +1,6 @@
"""LLM factory — centralised model instantiation via LiteLLM. """LLM factory — centralised model instantiation via LiteLLM.
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()`` Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
instead of directly constructing a provider-specific class. The model string instead of directly constructing a provider-specific class. The model string
follows the `LiteLLM model naming convention follows the `LiteLLM model naming convention
<https://docs.litellm.ai/docs/providers>`_: <https://docs.litellm.ai/docs/providers>`_:
@@ -18,7 +18,6 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
from __future__ import annotations from __future__ import annotations
import os import os
import warnings
from openai import AsyncOpenAI from openai import AsyncOpenAI
import litellm import litellm
@@ -33,14 +32,6 @@ from app.config.settings import settings
# Drop them silently instead of raising UnsupportedParamsError. # Drop them silently instead of raising UnsupportedParamsError.
litellm.drop_params = True litellm.drop_params = True
# Some provider responses include a plain dict in the `usage` field where a
# richer Pydantic model is expected. This warning is noisy but non-fatal.
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None: def _api_key_for_model(model: str) -> str | None:
"""Return the most appropriate API key for the given LiteLLM model string.""" """Return the most appropriate API key for the given LiteLLM model string."""

View File

@@ -43,21 +43,15 @@ _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
class MemoryMiddleware: class MemoryMiddleware:
"""Enrich orchestrator context with memory and persist interactions after.""" """Enrich agent context with memory and persist interactions after."""
def __init__(self, db: AsyncSession) -> None: def __init__(self, db: AsyncSession) -> None:
self._db = db self._db = db
# ── Public API ──────────────────────────────────────────────────────────── # ── Public API ────────────────────────────────────────────────────────────
async def enrich_context( async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
self, """Build memory context dict to inject into the agent before LLM call.
user_id: str,
message: str,
trace_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Build memory context dict to inject into the orchestrator before LLM call.
Returns a dict with keys: Returns a dict with keys:
core_memory — {key: plaintext_value, ...} core_memory — {key: plaintext_value, ...}
@@ -71,21 +65,9 @@ class MemoryMiddleware:
core = await self._load_core(user_id, fernet) core = await self._load_core(user_id, fernet)
associative = await self._load_associative(user_id, message, fernet) associative = await self._load_associative(user_id, message, fernet)
episodic = await self._load_episodic(user_id, fernet, session_id=session_id) episodic = await self._load_episodic(user_id, fernet)
proactive = await self._load_proactive(user_id, fernet) proactive = await self._load_proactive(user_id, fernet)
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
len(core),
len(associative),
len(episodic),
len(proactive),
)
return { return {
"core_memory": core, "core_memory": core,
"associative_memory": associative, "associative_memory": associative,
@@ -99,7 +81,6 @@ class MemoryMiddleware:
session_id: str, session_id: str,
message: str, message: str,
response: str, response: str,
trace_id: str | None = None,
) -> None: ) -> None:
"""Summarise and store a completed interaction in episodic memory. """Summarise and store a completed interaction in episodic memory.
@@ -122,19 +103,11 @@ class MemoryMiddleware:
self._db.add(row) self._db.add(row)
try: try:
await self._db.commit() await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: store_episode trace=%s user=%s tier=%s session=%s",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
session_id,
)
except Exception as exc: except Exception as exc:
logger.error("memory: store_episode failed user=%s: %s", user_id, exc) logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
await self._db.rollback() await self._db.rollback()
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None: async def update_core(self, user_id: str, key: str, value: str) -> None:
"""Upsert a core memory key/value for a user.""" """Upsert a core memory key/value for a user."""
fernet = await self._get_fernet(user_id) fernet = await self._get_fernet(user_id)
if fernet is None: if fernet is None:
@@ -160,176 +133,10 @@ class MemoryMiddleware:
)) ))
try: try:
await self._db.commit() await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: update_core trace=%s user=%s tier=%s key=%s",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
key,
)
except Exception as exc: except Exception as exc:
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc) logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
await self._db.rollback() await self._db.rollback()
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
"""Return core memory as editable blocks (label/value)."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryCore)
.where(MemoryCore.user_id == user_id)
.order_by(MemoryCore.key.asc())
)
rows = result.scalars().all()
out: list[dict[str, str]] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out.append({"label": row.key, "value": plaintext})
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
return out
async def get_core_block(self, user_id: str, label: str) -> str | None:
"""Return a single core memory block value by label."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return None
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == label,
)
)
row = result.scalar_one_or_none()
if row is None:
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
return None
value = _safe_decrypt(fernet, row.value_encrypted)
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
return value
async def delete_core(self, user_id: str, label: str) -> bool:
"""Delete a core memory block by label. Returns True if deleted."""
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == label,
)
)
row = result.scalar_one_or_none()
if row is None:
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
return False
await self._db.delete(row)
try:
await self._db.commit()
logger.info("memory: delete_core user=%s label=%s", user_id, label)
return True
except Exception as exc:
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
await self._db.rollback()
return False
async def append_core(self, user_id: str, label: str, content: str) -> None:
"""Append content to a core block, creating it if missing."""
current = await self.get_core_block(user_id, label)
if current is None:
await self.update_core(user_id, label, content)
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
return
await self.update_core(user_id, label, f"{current}\n{content}")
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
"""Replace one exact string inside a core block. Returns False if not found."""
current = await self.get_core_block(user_id, label)
if current is None or old not in current:
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
return False
await self.update_core(user_id, label, current.replace(old, new, 1))
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
return True
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
"""Insert a long-term archival memory entry."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, content)
row = MemoryAssociative(
id=str(uuid.uuid4()),
user_id=user_id,
content_encrypted=encrypted,
embedding=None,
entity_type=source,
entity_id=None,
)
self._db.add(row)
try:
await self._db.commit()
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
except Exception as exc:
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc())
.limit(100)
)
rows = result.scalars().all()
needle = query.strip().lower()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
return out
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
"""Search recall memory (episodic summaries) by keyword."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryEpisodic)
.where(MemoryEpisodic.user_id == user_id)
.order_by(MemoryEpisodic.created_at.desc())
.limit(100)
)
rows = result.scalars().all()
needle = query.strip().lower()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
return out
# ── Private helpers ─────────────────────────────────────────────────────── # ── Private helpers ───────────────────────────────────────────────────────
async def _get_fernet(self, user_id: str) -> Fernet | None: async def _get_fernet(self, user_id: str) -> Fernet | None:
@@ -341,16 +148,6 @@ class MemoryMiddleware:
return None return None
return Fernet(user.encryption_key.encode()) return Fernet(user.encryption_key.encode())
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
"""Load lightweight user debug fields for trace logs."""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return {"tier": None}
return {
"tier": user.tier,
}
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]: async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
result = await self._db.execute( result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id) select(MemoryCore).where(MemoryCore.user_id == user_id)
@@ -386,17 +183,10 @@ class MemoryMiddleware:
out.append(plaintext) out.append(plaintext)
return out return out
async def _load_episodic( async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
self,
user_id: str,
fernet: Fernet,
session_id: str | None = None,
) -> list[str]:
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
if session_id:
query = query.where(MemoryEpisodic.session_id == session_id)
result = await self._db.execute( result = await self._db.execute(
query select(MemoryEpisodic)
.where(MemoryEpisodic.user_id == user_id)
.order_by(MemoryEpisodic.created_at.desc()) .order_by(MemoryEpisodic.created_at.desc())
.limit(_EPISODIC_RECENT_N) .limit(_EPISODIC_RECENT_N)
) )

View File

@@ -1,47 +1,141 @@
"""Output formatter for deep-agent stream events.""" """Output Formatter — transforms deep-agent event streams into WS frame sequences.
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
* ``("token", str)`` — supervisor text token
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
HomeFormatter:
* Streams text tokens as-is → emits ``WsStreamText``
(text may contain inline ``<type>[id,...]</type>`` entity tags
for the frontend to parse and render as interactive components)
* Attaches mutations → injects into ``WsStreamEnd``
FloatingFormatter:
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
* Streams text tokens → emits ``WsStreamText``
* Attaches mutations → injects into ``WsStreamEnd``
"""
from __future__ import annotations from __future__ import annotations
import logging
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any from typing import Any
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText from app.schemas import (
WsFloatingDomain,
WsStreamEnd,
WsStreamStart,
WsStreamText,
)
logger = logging.getLogger(__name__)
# Map sub-agent tool name → floating domain / entity type
_AGENT_DOMAIN: dict[str, str] = {
"task_agent": "tasks",
"timeline_agent": "timelines",
"note_agent": "notes",
"project_agent": "projects",
}
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
class StreamFormatter: class HomeFormatter:
"""Convert `(event_type, data)` stream events into websocket frame models.""" """Consumes a deep-agent event stream and yields WS frames for the Home view.
Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
is responsible for parsing those and rendering interactive components.
Mutations are attached to ``WsStreamEnd``.
"""
def __init__(self, request_id: str) -> None: def __init__(self, request_id: str) -> None:
self.request_id = request_id self.request_id = request_id
self._mutations: list[dict] = []
async def format( async def format(
self, self,
event_stream: AsyncGenerator[tuple[str, Any], None], event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]: ) -> AsyncGenerator[WsFrame, None]:
started = False yield WsStreamStart(request_id=self.request_id)
async for event_type, data in event_stream: async for event_type, data in event_stream:
if event_type == "floating_domain": if event_type == "token":
if isinstance(data, dict): if data:
yield WsStreamText(request_id=self.request_id, chunk=data)
elif event_type == "mutations":
self._mutations = data or []
yield WsStreamEnd(
request_id=self.request_id,
mutations=[
{"action": m["action"], "table": m["table"], "data": m["data"]}
for m in self._mutations
],
)
class FloatingFormatter:
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
``task_agent`` → ``"tasks"``), then streams text tokens as plain
``WsStreamText``. No block parsing for floating context.
"""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._mutations: list[dict] = []
async def format(
self,
event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]:
domain_sent = False
async for event_type, data in event_stream:
if event_type == "tool_end" and not domain_sent:
# Sniff domain from the first sub-agent that completes
name = data.get("name", "")
domain = _AGENT_DOMAIN.get(name, "tasks")
yield WsFloatingDomain( yield WsFloatingDomain(
request_id=self.request_id, request_id=self.request_id,
domain=data, domain=domain, # type: ignore[arg-type]
) )
continue
if event_type != "token":
continue
if not started:
yield WsStreamStart(request_id=self.request_id) yield WsStreamStart(request_id=self.request_id)
started = True domain_sent = True
text = str(data or "") elif event_type == "token":
if text: if not domain_sent:
yield WsStreamText(request_id=self.request_id, chunk=text) # First token arrived before any tool_end — default domain
yield WsFloatingDomain(
if not started: request_id=self.request_id,
domain="tasks", # type: ignore[arg-type]
)
yield WsStreamStart(request_id=self.request_id) yield WsStreamStart(request_id=self.request_id)
yield WsStreamEnd(request_id=self.request_id) domain_sent = True
if data:
yield WsStreamText(request_id=self.request_id, chunk=data)
elif event_type == "mutations":
self._mutations = data or []
# If no events triggered domain_sent (edge case), still emit structure
if not domain_sent:
yield WsFloatingDomain(
request_id=self.request_id,
domain="tasks", # type: ignore[arg-type]
)
yield WsStreamStart(request_id=self.request_id)
yield WsStreamEnd(
request_id=self.request_id,
mutations=[
{"action": m["action"], "table": m["table"], "data": m["data"]}
for m in self._mutations
],
)

View File

@@ -7,18 +7,21 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
from __future__ import annotations from __future__ import annotations
import logging
from contextvars import ContextVar from contextvars import ContextVar
from typing import Any, Callable, Coroutine from typing import Any, Callable, Coroutine
from uuid import uuid4 from uuid import uuid4
logger = logging.getLogger(__name__)
# Holds the execute callback for the current WS session. # Holds the execute callback for the current WS session.
# Set by the chat WS handler before the orchestrator runs; cleared after. # Set by the chat WS handler before the deep agent runs; cleared after.
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar( _client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
"_client_executor" "_client_executor"
) )
# Optional collector that captures raw execute_on_client results. # Optional collector that captures raw execute_on_client results.
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results. # Set by the deep agent tool loop to capture CRUD mutations.
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar( _tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None "_tool_result_collector", default=None
) )
@@ -81,12 +84,17 @@ async def execute_on_client(
if limit is not None: if limit is not None:
payload["limit"] = limit payload["limit"] = limit
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
result = await callback(payload) result = await callback(payload)
if result is None:
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
else:
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
collector = _tool_result_collector.get(None) collector = _tool_result_collector.get(None)
if collector is not None: if collector is not None and action in ("insert", "update", "delete"):
collector.append({ collector.append({
"action": action, "action": action,
"table": table, "table": table,
"data": result, "data": data or {},
}) })
return result return result

View File

@@ -18,9 +18,7 @@ from app.config.settings import settings
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Startup: ensure agent tool modules are loaded. # Startup: initialise DB connection pool
import app.agents # noqa: F401
yield yield
# Shutdown: dispose SQLAlchemy connection pool # Shutdown: dispose SQLAlchemy connection pool
@@ -50,12 +48,17 @@ def create_app() -> FastAPI:
app.add_middleware(SanitizerMiddleware) app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware) app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agents, auth, billing, chat, device_ws from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
app.include_router(auth.router, prefix="/api/v1") app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1") app.include_router(chat.router, prefix="/api/v1")
app.include_router(storage.router, prefix="/api/v1")
app.include_router(vectors.router, prefix="/api/v1")
app.include_router(backup.router, prefix="/api/v1")
app.include_router(plugins.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1") app.include_router(billing.router, prefix="/api/v1")
app.include_router(agents.router, prefix="/api/v1") app.include_router(agents.router, prefix="/api/v1")
app.include_router(agent_setup.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1") app.include_router(device_ws.router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"]) @app.get("/api/v1/health", tags=["health"])

View File

@@ -0,0 +1,7 @@
"""Plugin marketplace package.
Three service classes introduced in Step 10:
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
- ``ReviewQueue`` — approval workflow + security checklist
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
"""

View File

@@ -0,0 +1,212 @@
"""Plugin catalog registry backed by PostgreSQL.
Maintains the authoritative list of plugins, their review status, and
aggregate install counts. All data is persisted in the ``plugins`` table.
Module-level singleton::
from app.marketplace.plugin_registry import registry
"""
from __future__ import annotations
import json
from typing import Any, Literal
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Plugin
from app.schemas import PluginListResponse, PluginManifest
_PAGE_SIZE = 20
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
try:
permissions = json.loads(p.permissions) if p.permissions else []
except (json.JSONDecodeError, TypeError):
permissions = []
return PluginManifest(
id=p.id,
name=p.name,
description=p.description,
version=p.version,
author=p.author_name,
permissions=permissions,
category=p.category,
price_cents=p.price_cents,
)
class PluginRegistry:
"""PostgreSQL-backed plugin catalog.
All methods accept an ``AsyncSession`` parameter so the calling route
controls the session lifecycle.
"""
# ── Queries ──────────────────────────────────────────────────────
async def list_plugins(
self,
db: AsyncSession,
category: str | None = None,
query: str | None = None,
page: int = 1,
sort: Literal["rating", "installs", "newest"] = "newest",
) -> PluginListResponse:
"""Return a page of approved plugins, optionally filtered and sorted."""
base = select(Plugin).where(Plugin.status == "approved")
if category:
base = base.where(Plugin.category == category)
if query:
pattern = f"%{query}%"
base = base.where(
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
)
# Count
count_q = select(func.count()).select_from(base.subquery())
total = (await db.execute(count_q)).scalar_one()
# Sort
if sort == "installs":
base = base.order_by(Plugin.install_count.desc())
elif sort == "rating":
base = base.order_by(Plugin.avg_rating.desc())
else: # newest
base = base.order_by(Plugin.created_at.desc())
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
rows = (await db.execute(base)).scalars().all()
return PluginListResponse(
plugins=[_plugin_to_manifest(r) for r in rows],
total=total,
page=page,
)
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
p = result.scalar_one_or_none()
if p is None:
return None
return {
"manifest": _plugin_to_manifest(p),
"status": p.status,
"install_count": p.install_count,
"avg_rating": p.avg_rating,
}
# ── Mutations ────────────────────────────────────────────────────
async def submit_plugin(
self,
db: AsyncSession,
manifest: PluginManifest,
package_s3_key: str,
) -> str:
"""Add *manifest* to the catalog with ``status='pending_review'``.
Returns the plugin_id. If a plugin with the same id already exists
it is overwritten (re-submission after rejection).
"""
plugin_id = manifest.id
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = existing.scalar_one_or_none()
if row is not None:
row.name = manifest.name
row.description = manifest.description
row.version = manifest.version
row.author_name = manifest.author
row.category = manifest.category
row.price_cents = manifest.price_cents
row.permissions = json.dumps(manifest.permissions)
row.status = "pending_review"
row.s3_package_key = package_s3_key
row.rejection_reason = None
else:
row = Plugin(
id=plugin_id,
name=manifest.name,
description=manifest.description,
version=manifest.version,
author_name=manifest.author,
category=manifest.category,
price_cents=manifest.price_cents,
permissions=json.dumps(manifest.permissions),
status="pending_review",
s3_package_key=package_s3_key,
install_count=0,
avg_rating=0.0,
)
db.add(row)
await db.commit()
return plugin_id
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
"""Set *plugin_id* status to ``'approved'``.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "approved"
row.rejection_reason = None
await db.commit()
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "rejected"
row.rejection_reason = reason
await db.commit()
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
"""Increment the install count for *plugin_id* (no-op if not found)."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = row.install_count + 1
await db.commit()
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
"""Decrement the install count for *plugin_id*, floored at 0."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = max(0, row.install_count - 1)
await db.commit()
# ── Internal helpers used by ReviewQueue ─────────────────────────
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all entries with status='pending_review'."""
result = await db.execute(
select(Plugin).where(Plugin.status == "pending_review")
)
rows = result.scalars().all()
return [
{
"manifest": _plugin_to_manifest(r),
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
}
for r in rows
]
# Module-level singleton
registry = PluginRegistry()

View File

@@ -0,0 +1,125 @@
"""Plugin review workflow backed by PostgreSQL.
Manages the approval queue for newly submitted plugins and enforces a
security checklist before any plugin is made visible in the marketplace.
Module-level singleton::
from app.marketplace.plugin_review import review_queue
"""
from __future__ import annotations
import re
from typing import Any, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.marketplace.plugin_registry import registry
from app.models import PluginReview as PluginReviewModel
from app.schemas import PluginManifest
# ── Security policy ───────────────────────────────────────────────────
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
{
"read:tasks",
"write:tasks",
"read:projects",
"write:projects",
"read:notes",
"write:notes",
"read:timelines",
"write:timelines",
"read:calendar",
"write:calendar",
}
)
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
def validate_manifest(manifest: PluginManifest) -> None:
"""Enforce the plugin security checklist.
Raises:
``ValueError`` on the first violation found. Callers should catch
this and return HTTP 422 / reject the submission.
Checks:
1. Plugin id matches ``^[a-z0-9-]+$``
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
3. No manifest field contains raw binary data
"""
if not _PLUGIN_ID_RE.match(manifest.id):
raise ValueError(
f"Invalid plugin id format: '{manifest.id}'. "
"Only lowercase letters, digits, and hyphens are allowed."
)
for perm in manifest.permissions:
if perm not in ALLOWED_PERMISSIONS:
raise ValueError(
f"Unknown permission: '{perm}'. "
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
)
for field_name, value in manifest.model_dump().items():
if isinstance(value, (bytes, bytearray)):
raise ValueError(
f"Binary content is not allowed in manifest field '{field_name}'."
)
class ReviewQueue:
"""Approval queue for pending plugin submissions.
Delegates status changes to the shared ``PluginRegistry`` singleton.
Review records are persisted in the ``plugin_reviews`` table.
"""
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all plugins currently awaiting review.
Each item is ``{plugin_id, manifest, submitted_at}``.
"""
entries = await registry.get_pending_entries(db)
return [
{
"plugin_id": e["manifest"].id,
"manifest": e["manifest"],
"submitted_at": e["submitted_at"],
}
for e in entries
]
async def submit_review(
self,
db: AsyncSession,
plugin_id: str,
reviewer_id: str,
decision: Literal["approved", "rejected"],
notes: str = "",
) -> None:
"""Record a review decision and update the plugin's status.
Raises:
``KeyError`` if *plugin_id* is not found in the registry.
"""
if decision == "approved":
await registry.approve_plugin(db, plugin_id)
else:
await registry.reject_plugin(db, plugin_id, reason=notes)
review = PluginReviewModel(
plugin_id=plugin_id,
reviewer_id=reviewer_id,
decision=decision,
notes=notes,
)
db.add(review)
await db.commit()
# Module-level singleton
review_queue = ReviewQueue()

View File

@@ -0,0 +1,233 @@
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
Records every plugin installation as a revenue event and facilitates
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
in the ``revenue_events`` table.
Module-level singleton::
from app.marketplace.revenue_share import revenue_share
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Any
import stripe as stripe_lib
from sqlalchemy import extract, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.marketplace.plugin_registry import registry
from app.models import Plugin, RevenueEvent
logger = logging.getLogger(__name__)
# ── Revenue split constants ───────────────────────────────────────────
DEVELOPER_SHARE: float = 0.70
PLATFORM_SHARE: float = 0.30
class RevenueShare:
"""Records installation revenue events and coordinates developer payouts.
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
is not configured, consistent with the rest of the billing layer.
"""
# ── Helpers ──────────────────────────────────────────────────────
@staticmethod
def _stripe_configured() -> bool:
return bool(settings.STRIPE_SECRET_KEY)
@staticmethod
def _stripe() -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Core operations ──────────────────────────────────────────────
async def record_install(
self,
db: AsyncSession,
plugin_id: str,
user_id: str,
amount_cents: int,
) -> None:
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
For free plugins (``amount_cents == 0``) no payment is initiated but
the event is still recorded for analytics.
For paid plugins the developer receives 70 % via a Stripe Connect
destination charge. If Stripe is not configured or the charge fails
the installation still succeeds (the event is recorded and the install
count is incremented) — a warning is logged for monitoring.
"""
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
stripe_transfer_id: str | None = None
if amount_cents > 0 and self._stripe_configured():
# Look up the plugin's author Stripe account from the DB
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = result.scalar_one_or_none()
developer_stripe_account: str | None = None
if plugin_row and plugin_row.author_id:
# Future: look up user.stripe_connect_account_id
developer_stripe_account = None # no real account yet
if developer_stripe_account:
try:
s = self._stripe()
transfer = s.Transfer.create(
amount=developer_share_cents,
currency="eur",
destination=developer_stripe_account,
description=f"Revenue share for plugin {plugin_id}",
metadata={"plugin_id": plugin_id, "user_id": user_id},
)
stripe_transfer_id = transfer["id"]
except Exception as exc:
logger.warning(
"Stripe Connect transfer failed for plugin %s: %s",
plugin_id,
exc,
)
else:
logger.debug(
"No Stripe account on file for plugin %s developer; "
"skipping transfer.",
plugin_id,
)
event = RevenueEvent(
plugin_id=plugin_id,
user_id=user_id,
amount_cents=amount_cents,
developer_share_cents=developer_share_cents,
stripe_transfer_id=stripe_transfer_id,
)
db.add(event)
await db.commit()
await registry.record_install(db, plugin_id)
async def get_earnings(
self,
db: AsyncSession,
developer_id: str,
period: str | None = None,
) -> dict[str, Any]:
"""Return aggregated earnings for *developer_id*.
``period`` is an optional ``YYYY-MM`` string to restrict the window.
Returns::
{
"developer_id": str,
"period": str | None,
"total_installs": int,
"total_revenue_cents": int,
"developer_share_cents": int,
}
"""
# Find plugin ids belonging to this developer (by author_name match)
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
plugin_result = await db.execute(plugin_q)
developer_plugin_ids = [row[0] for row in plugin_result.all()]
if not developer_plugin_ids:
return {
"developer_id": developer_id,
"period": period,
"total_installs": 0,
"total_revenue_cents": 0,
"developer_share_cents": 0,
}
query = select(
func.count().label("total_installs"),
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
if period:
# Filter by YYYY-MM: extract year and month from created_at
try:
year, month = period.split("-")
query = query.where(
extract("year", RevenueEvent.created_at) == int(year),
extract("month", RevenueEvent.created_at) == int(month),
)
except ValueError:
pass # invalid period format — return all
result = await db.execute(query)
row = result.one()
return {
"developer_id": developer_id,
"period": period,
"total_installs": row.total_installs,
"total_revenue_cents": row.total_revenue,
"developer_share_cents": row.dev_share,
}
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
Marks processed events with ``paid_at`` timestamp.
Stubs gracefully when Stripe is not configured.
"""
try:
year, month = period.split("-")
year_int, month_int = int(year), int(month)
except ValueError:
logger.warning("Invalid period format: %s", period)
return
result = await db.execute(
select(RevenueEvent).where(
RevenueEvent.plugin_id == plugin_id,
RevenueEvent.paid_at.is_(None),
extract("year", RevenueEvent.created_at) == year_int,
extract("month", RevenueEvent.created_at) == month_int,
)
)
unpaid = list(result.scalars().all())
total_dev_share = sum(e.developer_share_cents for e in unpaid)
if total_dev_share <= 0 or not unpaid:
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
return
if self._stripe_configured():
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = plugin_result.scalar_one_or_none()
developer_stripe_account: str | None = None # Future: fetch from DB
if plugin_row and developer_stripe_account:
try:
s = self._stripe()
s.Transfer.create(
amount=total_dev_share,
currency="eur",
destination=developer_stripe_account,
description=f"Payout for plugin {plugin_id} period {period}",
)
except Exception as exc:
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
return
paid_ts = datetime.now(timezone.utc)
for event in unpaid:
event.paid_at = paid_ts
await db.commit()
# Module-level singleton
revenue_share = RevenueShare()

View File

@@ -1,15 +1,19 @@
"""SQLAlchemy ORM models for all persistent tables. """SQLAlchemy ORM models for all persistent tables.
Only auth, billing, agent config, and memory data live here. Only auth, billing, storage metadata, and marketplace data live here.
User content (notes, tasks, etc.) lives exclusively on the client. User content (notes, tasks, etc.) is NEVER persisted server-side —
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
Table inventory: Table inventory:
users — account credentials + tier users — account credentials + tier
refresh_tokens — hashed refresh token store refresh_tokens — hashed refresh token store
subscriptions — Stripe subscription records subscriptions — Stripe subscription records
local_agent_configs — per-device batch agent configs storage_records — S3 blob metadata (no plaintext)
cloud_agent_configs — OAuth-backed cloud agent configs backup_metadata — encrypted backup manifests
agent_run_logs — execution history for all agents plugins — marketplace plugin catalog
plugin_installations — per-user install records
plugin_reviews — admin review decisions
revenue_events — Stripe Connect 70/30 split ledger
memory_core — per-user persistent key/value preferences (encrypted) memory_core — per-user persistent key/value preferences (encrypted)
memory_associative — per-user semantic memory with embeddings (encrypted) memory_associative — per-user semantic memory with embeddings (encrypted)
memory_episodic — per-user session summaries (encrypted) memory_episodic — per-user session summaries (encrypted)
@@ -22,6 +26,7 @@ import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from sqlalchemy import ( from sqlalchemy import (
BigInteger,
Boolean, Boolean,
DateTime, DateTime,
Enum, Enum,
@@ -31,6 +36,7 @@ from sqlalchemy import (
JSON, JSON,
String, String,
Text, Text,
UniqueConstraint,
Uuid, Uuid,
func, func,
) )
@@ -52,6 +58,8 @@ def _now() -> datetime:
# ── Enum types ──────────────────────────────────────────────────────────── # ── Enum types ────────────────────────────────────────────────────────────
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier") TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
AgentTypeEnum = Enum("local", "cloud", name="agent_type") AgentTypeEnum = Enum("local", "cloud", name="agent_type")
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status") AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider") CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
@@ -129,6 +137,151 @@ class Subscription(Base):
user: Mapped[User] = relationship(back_populates="subscription") user: Mapped[User] = relationship(back_populates="subscription")
class StorageRecord(Base):
__tablename__ = "storage_records"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class BackupMetadata(Base):
__tablename__ = "backup_metadata"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class Plugin(Base):
__tablename__ = "plugins"
id: Mapped[str] = mapped_column(String(255), primary_key=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
# nullable until developer account system is built
author_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
submitted_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
installations: Mapped[list[PluginInstallation]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
reviews: Mapped[list[PluginReview]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
revenue_events: Mapped[list[RevenueEvent]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
class PluginInstallation(Base):
__tablename__ = "plugin_installations"
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
installed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="installations")
class PluginReview(Base):
__tablename__ = "plugin_reviews"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
reviewer_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
reviewed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
class RevenueEvent(Base):
__tablename__ = "revenue_events"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
class LocalAgentConfig(Base): class LocalAgentConfig(Base):
__tablename__ = "local_agent_configs" __tablename__ = "local_agent_configs"

View File

@@ -50,6 +50,88 @@ class ChatResponse(BaseModel):
response: str response: str
# ── Backup ───────────────────────────────────────────────────────────
class BackupMetadata(BaseModel):
version: int
timestamp: int
checksum: str
chunk_count: int
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
class StorageRecord(BaseModel):
id: str
user_id: str
table: str
blob: bytes
checksum: str
created_at: int
updated_at: int
class StorageRecordCreate(BaseModel):
table: str
blob: bytes
checksum: str
class StorageRecordUpdate(BaseModel):
blob: bytes
checksum: str
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
class VectorItem(BaseModel):
id: str
blob: bytes # encrypted vector + metadata — backend never decrypts
checksum: str
class VectorUpsertRequest(BaseModel):
vectors: list[VectorItem]
class VectorSearchRequest(BaseModel):
query_blob: bytes # encrypted query — backend never decrypts
top_k: int = 10
class VectorSearchResult(BaseModel):
id: str
score: float
blob: bytes
class VectorSearchResponse(BaseModel):
results: list[VectorSearchResult]
# ── Plugin Marketplace ────────────────────────────────────────────────
class PluginManifest(BaseModel):
id: str
name: str
description: str
version: str
author: str
permissions: list[str]
category: str
price_cents: int = 0
class PluginListResponse(BaseModel):
plugins: list[PluginManifest]
total: int
page: int
class PluginInstallRequest(BaseModel):
plugin_id: str
# ── WebSocket Frame Protocol ────────────────────────────────────────── # ── WebSocket Frame Protocol ──────────────────────────────────────────
class WsFrameType(str, Enum): class WsFrameType(str, Enum):
@@ -60,6 +142,9 @@ class WsFrameType(str, Enum):
tool_result = "tool_result" tool_result = "tool_result"
final = "final" final = "final"
ping = "ping" ping = "ping"
agent_run = "agent_run"
agent_data = "agent_data"
agent_complete = "agent_complete"
device_hello = "device_hello" device_hello = "device_hello"
# ── v3 frame types ───────────────────────────────────────────────── # ── v3 frame types ─────────────────────────────────────────────────
home_request = "home_request" home_request = "home_request"
@@ -71,10 +156,6 @@ class WsFrameType(str, Enum):
data_request = "data_request" data_request = "data_request"
data_response = "data_response" data_response = "data_response"
mutation = "mutation" mutation = "mutation"
# ── v4 journey frame types ────────────────────────────────────────
journey_start = "journey_start"
journey_message = "journey_message"
journey_reply = "journey_reply"
class WsToolCall(BaseModel): class WsToolCall(BaseModel):
@@ -127,6 +208,31 @@ class WsDeviceHello(BaseModel):
agent_ids: list[str] = Field(default_factory=list) agent_ids: list[str] = Field(default_factory=list)
class WsAgentRun(BaseModel):
"""Server → Client: trigger an agent run on the connected device."""
type: Literal[WsFrameType.agent_run] = WsFrameType.agent_run
run_id: str
agent_id: str
config: dict[str, Any]
class WsAgentData(BaseModel):
"""Client → Server: files read by the local agent."""
type: Literal[WsFrameType.agent_data] = WsFrameType.agent_data
run_id: str
files: list[dict[str, Any]]
class WsAgentComplete(BaseModel):
"""Client → Server: Electron signals it has finished reading files."""
type: Literal[WsFrameType.agent_complete] = WsFrameType.agent_complete
run_id: str
files_read: int
errors: list[str] = Field(default_factory=list)
# ── WebSocket v3 Frame Models ───────────────────────────────────────── # ── WebSocket v3 Frame Models ─────────────────────────────────────────
@@ -173,14 +279,7 @@ class WsStreamEnd(BaseModel):
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
request_id: str request_id: str
mutations: list[dict[str, Any]] = Field(default_factory=list)
class WsDomain(BaseModel):
"""Structured floating domain payload for UI routing decisions."""
type: Literal["task", "timeline", "project", "node"]
id: str | None = None
section: Literal["task", "timeline", "note"] | None = None
class WsFloatingDomain(BaseModel): class WsFloatingDomain(BaseModel):
@@ -188,7 +287,7 @@ class WsFloatingDomain(BaseModel):
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
request_id: str request_id: str
domain: WsDomain domain: Literal["tasks", "timelines", "notes", "projects"]
# ── Agent Catalog ───────────────────────────────────────────────────── # ── Agent Catalog ─────────────────────────────────────────────────────
@@ -197,28 +296,84 @@ class AgentCatalogItem(BaseModel):
type: str type: str
name: str name: str
description: str description: str
config_schema: dict[str, Any] = Field(default_factory=dict)
class AgentCreationCheckRequest(BaseModel): # ── Local Agent Config ────────────────────────────────────────────────
active_agents: int = Field(ge=0, default=0)
class LocalAgentConfigCreate(BaseModel):
name: str
device_id: str
directory_paths: list[str]
data_types: list[str]
prompt_template: str
file_extensions: list[str]
schedule_cron: str
class AgentCreationCheckResponse(BaseModel): class LocalAgentConfigUpdate(BaseModel):
allowed: bool name: str | None = None
tier: BillingTier device_id: str | None = None
active_agents: int directory_paths: list[str] | None = None
limit: int data_types: list[str] | None = None
prompt_template: str | None = None
file_extensions: list[str] | None = None
schedule_cron: str | None = None
enabled: bool | None = None
class AgentTriggerRequest(BaseModel): class LocalAgentConfigResponse(BaseModel):
directory: str = Field(min_length=1) id: str
device_id: str = Field(default="") name: str
agent_id: str | None = None # FE stable agent ID (electron-store UUID) device_id: str
what_to_extract: list[str] = Field(min_length=1) directory_paths: list[str]
actions_by_type: dict[str, list[str]] | None = None data_types: list[str]
batch_interval: str = Field(min_length=1) prompt_template: str
custom_agent_prompt: str = Field(min_length=1) file_extensions: list[str]
active_agents: int = Field(ge=0, default=0) schedule_cron: str
enabled: bool
last_run_at: int | None
created_at: int
updated_at: int
# ── Cloud Agent Config ────────────────────────────────────────────────
class CloudAgentConfigCreate(BaseModel):
provider: Literal["gmail", "teams", "outlook"]
name: str
data_types: list[str]
prompt_template: str
oauth_token_encrypted: str
schedule_cron: str
filter_config: dict[str, Any] | None = None
class CloudAgentConfigUpdate(BaseModel):
provider: Literal["gmail", "teams", "outlook"] | None = None
name: str | None = None
data_types: list[str] | None = None
prompt_template: str | None = None
oauth_token_encrypted: str | None = None
schedule_cron: str | None = None
filter_config: dict[str, Any] | None = None
enabled: bool | None = None
class CloudAgentConfigResponse(BaseModel):
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
id: str
provider: Literal["gmail", "teams", "outlook"]
name: str
data_types: list[str]
prompt_template: str
schedule_cron: str
filter_config: dict[str, Any] | None
enabled: bool
last_run_at: int | None
created_at: int
updated_at: int
# ── Agent Run Log ───────────────────────────────────────────────────── # ── Agent Run Log ─────────────────────────────────────────────────────
@@ -237,3 +392,18 @@ class AgentRunLogResponse(BaseModel):
# ── Chatbot Journey ─────────────────────────────────────────────────── # ── Chatbot Journey ───────────────────────────────────────────────────
class JourneyStartRequest(BaseModel):
agent_type: Literal["local", "cloud"]
agent_id: str | None = None
class JourneyMessageRequest(BaseModel):
session_id: str
message: str
class JourneyResponse(BaseModel):
session_id: str
message: str
done: bool
prompt_template: str | None = None

1
app/storage/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Cloud storage layer — E2E encrypted blobs and vectors."""

106
app/storage/blob_store.py Normal file
View File

@@ -0,0 +1,106 @@
"""S3-backed store for E2E-encrypted blobs.
Keys are structured as ``{user_id}/{table}/{record_id}``.
The backend never inspects blob content — it stores and retrieves opaque bytes.
"""
from __future__ import annotations
from typing import Any
import boto3
from app.config.settings import settings
class BlobStore:
"""Thin wrapper around boto3 S3.
All blobs must be E2E encrypted by the client before upload.
The backend adds SSE-S3 as an extra layer of at-rest encryption
but cannot decrypt the inner client-side payload.
"""
def _client(self) -> Any:
kwargs: dict[str, Any] = {
"region_name": settings.S3_REGION,
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
}
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
return boto3.client("s3", **kwargs)
@staticmethod
def _key(user_id: str, table: str, record_id: str) -> str:
return f"{user_id}/{table}/{record_id}"
async def upload(
self,
user_id: str,
table: str,
record_id: str,
blob: bytes,
checksum: str,
) -> str:
"""Store *blob* in S3 and return the S3 key.
Args:
user_id: Owner of the blob (used as key prefix).
table: Logical table name (e.g. ``"tasks"``).
record_id: Record UUID.
blob: Raw bytes (pre-encrypted by client).
checksum: SHA-256 hex digest supplied by the client; stored as
object metadata for download-time verification.
Returns:
The S3 key under which the blob was stored.
"""
key = self._key(user_id, table, record_id)
self._client().put_object(
Bucket=settings.S3_BUCKET,
Key=key,
Body=blob,
ServerSideEncryption="AES256", # SSE-S3 at rest
Metadata={"checksum": checksum},
)
return key
async def download(self, user_id: str, s3_key: str) -> bytes:
"""Retrieve the blob stored at *s3_key*.
*user_id* is retained in the signature so higher-level code can
enforce ownership without re-parsing the key.
Raises:
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
object does not exist.
"""
response = self._client().get_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
return response["Body"].read()
async def delete(self, user_id: str, s3_key: str) -> None:
"""Delete the object at *s3_key*.
S3 ``delete_object`` is idempotent — it succeeds even if the key does
not exist.
"""
self._client().delete_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
async def list_keys(self, user_id: str, table: str) -> list[str]:
"""Return all S3 keys for a given user + table combination.
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
"""
prefix = f"{user_id}/{table}/"
response = self._client().list_objects_v2(
Bucket=settings.S3_BUCKET,
Prefix=prefix,
)
return [obj["Key"] for obj in response.get("Contents", [])]

32
app/storage/encryption.py Normal file
View File

@@ -0,0 +1,32 @@
"""Integrity verification only — the backend NEVER decrypts user data."""
from __future__ import annotations
import hashlib
import hmac
from fastapi import HTTPException
def verify_checksum(blob: bytes, checksum: str) -> bool:
"""Return ``True`` if SHA-256(blob) matches *checksum*.
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
timing-based side-channel attacks.
"""
computed = hashlib.sha256(blob).hexdigest()
return hmac.compare_digest(computed, checksum)
def reject_if_tampered(blob: bytes, checksum: str) -> None:
"""Raise ``HTTP 400`` if the blob does not match its checksum.
Call this before storing or forwarding any client-provided blob.
The backend never holds decryption keys — this check only verifies
that the opaque bytes arrived intact.
"""
if not verify_checksum(blob, checksum):
raise HTTPException(
status_code=400,
detail="Checksum mismatch: blob integrity check failed",
)

205
app/storage/vector_store.py Normal file
View File

@@ -0,0 +1,205 @@
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
Vectors are pre-encrypted blobs from the client. The backend stores them
alongside a deterministic 32-dim float representation derived from the blob's
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
is a known trade-off documented in the backend plan.
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
``user_id`` payload field on a shared collection.
"""
from __future__ import annotations
import base64
import hashlib
from typing import Any
from pinecone import Pinecone
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
from app.config.settings import settings
from app.schemas import VectorItem, VectorSearchResult
_QDRANT_COLLECTION = "adiuva_vectors"
def _blob_to_vector(blob: bytes) -> list[float]:
"""Derive a 32-dim float vector from *blob* for storage purposes only.
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
normalises each byte to the range [-1.0, 1.0]. This vector carries no
semantic meaning on encrypted data.
"""
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
class VectorStore:
"""Thin wrapper around Pinecone or Qdrant.
The backend to use is selected at runtime:
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
"""
def _use_pinecone(self) -> bool:
return bool(settings.PINECONE_API_KEY)
# ── Pinecone helpers ──────────────────────────────────────────────
def _pinecone_index(self) -> Any:
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
return pc.Index(settings.PINECONE_INDEX)
# ── Qdrant helpers ────────────────────────────────────────────────
def _qdrant_client(self) -> Any:
return QdrantClient(
url=settings.QDRANT_URL,
api_key=settings.QDRANT_API_KEY or None,
)
# ── Public API ────────────────────────────────────────────────────
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
"""Store encrypted vectors in the backend.
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
so it can be returned verbatim during search.
Args:
user_id: Used as Pinecone namespace or Qdrant payload field.
vectors: List of encrypted vector items from the client.
"""
if self._use_pinecone():
await self._pinecone_upsert(user_id, vectors)
else:
await self._qdrant_upsert(user_id, vectors)
async def search(
self,
user_id: str,
query_blob: bytes,
top_k: int,
) -> list[VectorSearchResult]:
"""Query the vector store and return encrypted result blobs.
The query vector is derived from *query_blob* using the same
deterministic mapping as upsert.
Args:
user_id: Scopes the search to this user's namespace.
query_blob: Encrypted query from the client.
top_k: Maximum number of results to return.
Returns:
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
"""
if self._use_pinecone():
return await self._pinecone_search(user_id, query_blob, top_k)
return await self._qdrant_search(user_id, query_blob, top_k)
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
"""Remove vectors by ID, scoped to *user_id*.
Args:
user_id: Namespace / payload filter to prevent cross-user deletion.
vector_ids: List of vector IDs to remove.
"""
if self._use_pinecone():
await self._pinecone_delete(user_id, vector_ids)
else:
await self._qdrant_delete(user_id, vector_ids)
# ── Pinecone implementation ───────────────────────────────────────
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
index = self._pinecone_index()
records = [
{
"id": v.id,
"values": _blob_to_vector(v.blob),
"metadata": {
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
}
for v in vectors
]
index.upsert(vectors=records, namespace=user_id)
async def _pinecone_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
index = self._pinecone_index()
query_vector = _blob_to_vector(query_blob)
response = index.query(
vector=query_vector,
top_k=top_k,
namespace=user_id,
include_metadata=True,
)
results: list[VectorSearchResult] = []
for match in response.get("matches", []):
blob_bytes = base64.b64decode(match["metadata"]["blob"])
results.append(
VectorSearchResult(
id=match["id"],
score=match["score"],
blob=blob_bytes,
)
)
return results
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
index = self._pinecone_index()
index.delete(ids=vector_ids, namespace=user_id)
# ── Qdrant implementation ─────────────────────────────────────────
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
client = self._qdrant_client()
points = [
PointStruct(
id=v.id,
vector=_blob_to_vector(v.blob),
payload={
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
)
for v in vectors
]
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
async def _qdrant_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
client = self._qdrant_client()
query_vector = _blob_to_vector(query_blob)
hits = client.search(
collection_name=_QDRANT_COLLECTION,
query_vector=query_vector,
query_filter=Filter(
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
),
limit=top_k,
)
return [
VectorSearchResult(
id=str(hit.id),
score=hit.score,
blob=base64.b64decode(hit.payload["blob"]),
)
for hit in hits
]
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
client = self._qdrant_client()
client.delete(
collection_name=_QDRANT_COLLECTION,
points_selector=PointIdsList(points=vector_ids),
)

View File

@@ -36,6 +36,37 @@ services:
# image: redis:7-alpine # image: redis:7-alpine
# restart: unless-stopped # restart: unless-stopped
# ── Local S3-compatible storage (MinIO) ──
minio:
image: minio/minio:latest
command: server /data --console-address ":9001"
ports:
- "9000:9000"
- "9001:9001"
environment:
MINIO_ROOT_USER: minioadmin
MINIO_ROOT_PASSWORD: minioadmin
volumes:
- minio_data:/data
healthcheck:
test: ["CMD", "mc", "ready", "local"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
# ── Local vector store (Qdrant) ──
qdrant:
image: qdrant/qdrant:latest
ports:
- "6333:6333"
- "6334:6334"
volumes:
- qdrant_data:/qdrant/storage
restart: unless-stopped
volumes: volumes:
postgres_data: postgres_data:
minio_data:
qdrant_data:
copilot_tokens: copilot_tokens:

View File

@@ -1,941 +0,0 @@
# Adiuva — Architettura Microservizi (MVP)
## Panoramica
Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
```
┌──────────────┐
│ Cloudflare │
│ (DNS + CDN) │
└──────┬───────┘
│ HTTPS / WSS
┌──────▼───────┐
│ Traefik │
│ API Gateway │
│ (routing, │
│ TLS, rate │
│ limiting) │
└──────┬───────┘
┌──────────┬───────────┼───────────┐
│ │ │ │
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
│ Auth │ │ Chat │ │ Agent │ │Billing │
│ Service │ │Service│ │ Service │ │Service │
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
│ │ │ │
┌─────▼──────────▼──────────▼───────────▼────┐
│ Infrastruttura │
│ PostgreSQL │ Redis │ Qdrant │
└─────────────────────────────────────────────┘
```
---
## 1. Suddivisione dei Servizi
### 1.1 Auth Service (`auth-service`)
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
| Endpoint originale | Metodo |
|---|---|
| `/api/v1/auth/register` | POST |
| `/api/v1/auth/login` | POST |
| `/api/v1/auth/refresh` | POST |
| `/api/v1/auth/me` | GET / PUT |
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
**Modifica chiave — JWT con RS256**:
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
- L'Auth Service firma i JWT con la **chiave privata**.
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
```python
# auth-service/app/auth/jwt.py
from cryptography.hazmat.primitives.asymmetric import rsa
from jose import jwt
PRIVATE_KEY = ... # Da env/secret
PUBLIC_KEY = ... # Derivata o da env
def create_access_token(user_id: str, tier: str) -> str:
return jwt.encode(
{"sub": user_id, "tier": tier, "exp": ...},
PRIVATE_KEY,
algorithm="RS256",
)
```
```python
# shared/auth.py (usato da tutti gli altri servizi)
from jose import jwt
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
def verify_token(token: str) -> dict:
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
```
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
---
### 1.2 Chat Service (`chat-service`) ⭐ Real-time
**Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
| Endpoint | Tipo |
|---|---|
| `/api/v1/ws/device` | WebSocket (connessione persistente) |
| `/api/v1/chat` | POST (REST fallback) |
**Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
**Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
**Scaling**: 2N repliche. Sticky cookies per le WS + Redis per cross-instance.
---
### 1.3 Agent Service (`agent-service`) ⭐ Batch
**Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
| Endpoint | Tipo |
|---|---|
| `/api/v1/agents/catalog` | GET |
| `/api/v1/agents/can-create` | POST |
| `/api/v1/agents/trigger` | POST |
| `/api/v1/agents/journey/start` | POST (o WS relay) |
| `/api/v1/agents/journey/message` | POST (o WS relay) |
**Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
```
┌──────────────┐ ┌──────────────┐ ┌──────────┐
│ Agent Service│ │ Redis │ │ Chat │
│ (batch run) │ │ │ │ Service │
│ │ │ │ │ (ha WS) │
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
│ from │ │ │ │ al │
│ device │ │ │ │ device│
│ │ │ │ │ via WS│
│ │ SUBSCRIBE │ │ PUBLISH │ │
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
│ risultato │ │ │ │ reply │
└──────────────┘ └──────────────┘ └──────────┘
```
**Scaling**: 1N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
---
### 1.4 Billing Service (`billing-service`)
**Responsabilità**: Stripe checkout, webhook, subscription management.
| Endpoint originale | Metodo |
|---|---|
| `/api/v1/billing/checkout` | POST |
| `/api/v1/billing/webhook` | POST |
| `/api/v1/billing/subscription` | GET / DELETE |
**Database**: Tabelle `subscriptions` (schema `billing`).
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
**Scaling**: 1 replica sufficiente. Basso traffico.
---
### 1.5 Servizi esclusi dall'MVP
I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
| Servizio | Responsabilità | Note |
|---|---|---|
| **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
| **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
---
## 2. Tier Check — Dove e Come
Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
### Strategia: Tier nel JWT
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
```json
{
"sub": "user_123",
"tier": "pro",
"exp": 1742515200,
"iat": 1742511600
}
```
Ogni servizio:
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
2. Legge `payload["tier"]`**zero chiamate extra**
3. Applica le sue regole di enforcement localmente
```python
# shared/auth.py — dependency FastAPI condivisa
from fastapi import Depends, HTTPException, Request
from jose import jwt
PUBLIC_KEY = ...
class CurrentUser:
def __init__(self, user_id: str, tier: str):
self.user_id = user_id
self.tier = tier
async def get_current_user(request: Request) -> CurrentUser:
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
def require_tier(*allowed_tiers: str):
"""Dependency che blocca se il tier non è tra quelli ammessi."""
async def check(user: CurrentUser = Depends(get_current_user)):
if user.tier not in allowed_tiers:
raise HTTPException(403, "Tier insufficient")
return user
return check
```
### Cosa succede quando il tier cambia (upgrade/downgrade)?
```
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
│ │ │ Service │ (Redis pub/sub) │ Service │
└──────────┘ └──────────┘ └────┬─────┘
UPDATE users
SET tier = 'power'
Al prossimo /refresh
il JWT conterrà tier='power'
```
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 1530 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
### Dove si applica in ciascun servizio
| Servizio | Enforcement |
|---|---|
| **Auth Service** | Nessuno (è lui che scrive il tier) |
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
### Rate-limit distribuito via Redis
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
```python
# shared/middleware/rate_limit.py
import redis.asyncio as aioredis
class DistributedRateLimiter:
def __init__(self, redis: aioredis.Redis):
self._redis = redis
async def check(self, user_id: str, tier: str, service: str) -> bool:
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
max_req = limits.get(tier, 20)
key = f"rate:{service}:{user_id}"
pipe = self._redis.pipeline()
pipe.incr(key)
pipe.expire(key, 60)
count, _ = await pipe.execute()
return count <= max_req
```
---
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
`DeviceConnectionManager` è un **singleton in-memory**:
```python
class DeviceConnectionManager:
def __init__(self):
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
```
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
### La soluzione: Redis Pub/Sub + Registry
```
┌──────────────────────────────────────────────────────────────┐
│ Redis │
│ │
│ Hash: ws:connections │
│ user_123 → instance_A │
│ user_456 → instance_B │
│ │
│ Pub/Sub channels: │
│ tool_call:{user_id} → tool call payloads │
│ tool_result:{call_id} → tool result payloads │
│ stream:{user_id} → text_chunk streaming │
└──────────────────────────────────────────────────────────────┘
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
┌───────────────────────┐ ┌───────────────────────┐
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
│ tool_call:user_123│ │ → user_123 è su A │
│ │ │ │
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
│ da Redis channel │ │ tool_call:user_123 │
│ │ │ {id, action, ...} │
│ 3. Invia al device │ │ │
│ via WS │ │ 4. SUBSCRIBE │
│ │ │ tool_result:{id} │
│ 4. Device risponde │ │ │
│ tool_result │──────────│► 5. Riceve risultato │
│ │ │ │
│ 5. PUBLISH │ │ │
│ tool_result:{id} │ │ │
└───────────────────────┘ └───────────────────────┘
```
### Implementazione: `RedisDeviceManager`
```python
# chat-service/app/core/device_manager.py
import asyncio
import json
import os
import redis.asyncio as aioredis
from dataclasses import dataclass, field
from fastapi import WebSocket
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
@dataclass
class LocalConnection:
ws: WebSocket
device_id: str
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
class RedisDeviceManager:
"""Device manager backed by Redis for cross-instance communication."""
def __init__(self, redis_url: str = "redis://redis:6379"):
self._redis = aioredis.from_url(redis_url)
self._pubsub = self._redis.pubsub()
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
async def start(self):
"""Avvia il listener Redis per tool_call in arrivo."""
asyncio.create_task(self._listen_tool_calls())
# ── Registrazione ──
async def register(self, user_id: str, device_id: str, ws: WebSocket):
# Registra localmente
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
# Registra in Redis quale istanza ha la connessione
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
# Sottoscrivi ai tool_call per questo utente
await self._pubsub.subscribe(f"tool_call:{user_id}")
async def unregister(self, user_id: str):
conn = self._local.pop(user_id, None)
if conn:
for fut in conn.pending_calls.values():
if not fut.done():
fut.cancel()
await self._redis.hdel("ws:connections", user_id)
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
# ── Presenza ──
async def is_online(self, user_id: str) -> bool:
return await self._redis.hexists("ws:connections", user_id)
# ── Tool-call round-trip (cross-instance) ──
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
"""
Invia un tool_call al device dell'utente.
Funziona sia che la WS sia locale che su un'altra istanza.
"""
call_id = payload["id"]
# Caso 1: connessione locale → invio diretto
if user_id in self._local:
conn = self._local[user_id]
loop = asyncio.get_event_loop()
fut: asyncio.Future[dict] = loop.create_future()
conn.pending_calls[call_id] = fut
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
return await asyncio.wait_for(fut, timeout=30.0)
# Caso 2: connessione remota → Redis pub/sub
loop = asyncio.get_event_loop()
fut = loop.create_future()
self._remote_futures[call_id] = fut
# Sottoscrivi al canale di risposta
result_channel = f"tool_result:{call_id}"
await self._pubsub.subscribe(result_channel)
# Pubblica il tool_call
await self._redis.publish(
f"tool_call:{user_id}",
json.dumps(payload),
)
try:
return await asyncio.wait_for(fut, timeout=30.0)
finally:
self._remote_futures.pop(call_id, None)
await self._pubsub.unsubscribe(result_channel)
# ── Risoluzione tool_result (da WS locale) ──
def resolve_local(self, user_id: str, call_id: str, result: dict):
conn = self._local.get(user_id)
if conn:
fut = conn.pending_calls.pop(call_id, None)
if fut and not fut.done():
fut.set_result(result)
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
"""Chiamato quando il device locale invia un tool_result."""
self.resolve_local(user_id, call_id, result)
# Pubblica anche su Redis per l'istanza remota che aspetta
await self._redis.publish(
f"tool_result:{call_id}",
json.dumps(result),
)
# ── Listener Redis ──
async def _listen_tool_calls(self):
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
async for message in self._pubsub.listen():
if message["type"] != "message":
continue
channel = message["channel"]
if isinstance(channel, bytes):
channel = channel.decode()
data = json.loads(message["data"])
if channel.startswith("tool_call:"):
# Un'altra istanza vuole che inviamo un tool_call al nostro device
user_id = channel.split(":", 1)[1]
conn = self._local.get(user_id)
if conn:
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
elif channel.startswith("tool_result:"):
# Risposta a un tool_call che abbiamo inviato tramite Redis
call_id = channel.split(":", 1)[1]
fut = self._remote_futures.pop(call_id, None)
if fut and not fut.done():
fut.set_result(data)
# ── Stream cross-instance ──
async def publish_stream_chunk(self, user_id: str, chunk: dict):
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
```
---
## 4. Struttura Directory Proposta (MVP)
```
adiuva-api/
├── docker-compose.yml # Orchestrazione completa
├── docker-compose.dev.yml # Override per sviluppo locale
├── shared/ # Codice condiviso (montato come volume)
│ ├── auth.py # JWT verification (chiave pubblica)
│ ├── schemas.py # Pydantic schemas condivisi
│ ├── middleware/
│ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
│ │ └── sanitizer.py
│ └── models/
│ └── base.py # SQLAlchemy base condivisa
├── auth-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # users, refresh_tokens
│ ├── routes/
│ │ └── auth.py
│ └── services/
│ ├── jwt_service.py # RS256 signing
│ └── user_service.py
├── chat-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # memory_*
│ ├── routes/
│ │ ├── device_ws.py # WS connection owner
│ │ └── chat.py # REST fallback
│ ├── core/
│ │ ├── device_manager.py # RedisDeviceManager
│ │ ├── deep_agent.py # Home + floating chat
│ │ ├── memory_middleware.py
│ │ ├── ws_context.py
│ │ ├── output_formatter.py
│ │ └── llm.py
│ └── agents/ # Tool definitions (used by deep_agent)
│ ├── task_agent.py
│ ├── project_agent.py
│ ├── note_agent.py
│ └── timeline_agent.py
├── agent-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # agent_run_logs, local/cloud_agent_configs
│ ├── routes/
│ │ ├── agents.py # catalog, can-create, trigger
│ │ └── agent_setup.py # journey start/message
│ ├── core/
│ │ ├── agent_runner.py # Batch classify → process
│ │ ├── agent_registry.py
│ │ ├── redis_executor.py # execute_on_client via Redis pub/sub
│ │ └── llm.py
│ └── agents/
│ ├── task_agent.py # Tool definitions (batch context)
│ ├── project_agent.py
│ ├── note_agent.py
│ ├── timeline_agent.py
│ └── filesystem_agent.py
├── billing-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # subscriptions
│ ├── routes/
│ │ └── billing.py
│ └── services/
│ ├── stripe_service.py
│ └── tier_manager.py
└── infra/
├── traefik/
│ └── traefik.yml
├── keys/
│ ├── jwt_private.pem # Solo auth-service
│ └── jwt_public.pem # Tutti i servizi
└── alembic/ # Migrazioni condivise o per-servizio
```
---
## 5. Docker Compose — Configurazione MVP
```yaml
# docker-compose.yml
services:
# ══════════════════════════════════════════════════════════
# API Gateway
# ══════════════════════════════════════════════════════════
traefik:
image: traefik:v3.2
command:
- "--api.insecure=true"
- "--providers.docker=true"
- "--providers.docker.exposedbydefault=false"
- "--entrypoints.web.address=:80"
- "--entrypoints.websecure.address=:443"
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
ports:
- "80:80"
- "443:443"
- "8080:8080" # Dashboard Traefik (disabilitare in prod)
volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro
- ./infra/certs:/certs:ro
restart: unless-stopped
# ══════════════════════════════════════════════════════════
# Auth Service (2 repliche)
# ══════════════════════════════════════════════════════════
auth-service:
build: ./auth-service
deploy:
replicas: 2
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
SERVICE_NAME: auth
secrets:
- jwt_private_key
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
- "traefik.http.services.auth.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Chat Service — Real-time WS + Chat (scalabile)
# ══════════════════════════════════════════════════════════
chat-service:
build: ./chat-service
deploy:
replicas: 2
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: chat
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
# REST chat endpoint
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
- "traefik.http.services.chat.loadbalancer.server.port=8000"
# WebSocket route con sticky session
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
- "traefik.http.routers.ws.service=chat-ws"
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Agent Service — Batch processing (scalabile indipendentemente)
# ══════════════════════════════════════════════════════════
agent-service:
build: ./agent-service
deploy:
replicas: 2
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: agent
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
- "traefik.http.services.agents.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Billing Service (1 replica)
# ══════════════════════════════════════════════════════════
billing-service:
build: ./billing-service
deploy:
replicas: 1
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: billing
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
- "traefik.http.services.billing.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Infrastruttura
# ══════════════════════════════════════════════════════════
db:
image: pgvector/pgvector:pg16
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: adiuva
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
redis:
image: redis:7-alpine
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
volumes:
- redis_data:/data
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 3s
retries: 5
restart: unless-stopped
qdrant:
image: qdrant/qdrant:latest
volumes:
- qdrant_data:/qdrant/storage
restart: unless-stopped
secrets:
jwt_private_key:
file: ./infra/keys/jwt_private.pem
jwt_public_key:
file: ./infra/keys/jwt_public.pem
volumes:
postgres_data:
redis_data:
qdrant_data:
```
---
## 6. Configurazione Cloudflare + VPS
### 6.1 DNS
```
api.tuodominio.com → A record → IP del VPS
→ Proxy: ON (orange cloud)
```
### 6.2 Cloudflare Settings
| Setting | Valore | Motivo |
|---------|--------|--------|
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
| Under Attack Mode | Off (attivare se necessario) | |
### 6.3 TLS sul VPS
Due opzioni:
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
```yaml
# traefik.yml — con Cloudflare Origin Certificate
entryPoints:
websecure:
address: ":443"
tls:
certificates:
- certFile: /certs/origin.pem
keyFile: /certs/origin-key.pem
```
### 6.4 Rete VPS
```bash
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
# https://www.cloudflare.com/ips/
ufw default deny incoming
ufw allow from 173.245.48.0/20 to any port 443
ufw allow from 103.21.244.0/22 to any port 443
# ... (tutti gli IP range di Cloudflare)
ufw allow ssh
ufw enable
```
---
## 7. Comunicazione Inter-Servizio
### 7.1 Redis Pub/Sub — Event Bus
```
┌──────────┐ tier_changed:user_123 ┌──────────┐
│ Billing │ ────────────────────────► │ Auth │
│ Service │ │ Service │
└──────────┘ └──────────┘
┌──────────┐ tool_call:user_123 ┌──────────┐
│ Agent │ ────────────────────────► │ Chat │
│ Service │ │ Service │
│ (batch) │ ◄────────────────────────│ (ha WS) │
└──────────┘ tool_result:{call_id} └──────────┘
```
### 7.2 Health Checks e Service Discovery
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
- **Redis pub/sub** (tool-call cross-instance, tier events)
- **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
- **PostgreSQL** (dati persistenti condivisi)
---
## 8. Piano di Migrazione Incrementale (MVP)
### Fase 1 — Preparazione (nel monolite attuale)
1. Aggiungere Redis al `docker-compose.yml` attuale
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
4. Estrarre `shared/` con auth verification, schemas, middleware
### Fase 2 — Auth Service (primo split)
1. Estrarre `auth.py` routes + models in `auth-service/`
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
4. Il monolite continua a servire tutto il resto
### Fase 3 — Billing Service
1. Estrarre billing routes, Stripe service, tier manager
2. Configurare Redis pub/sub per `tier_changed` events
3. Routare via Traefik
### Fase 4 — Split Chat + Agent (il più delicato)
1. Il monolite residuo contiene WS + chat + agents
2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
### Fase 5 — Scaling test
1. Scalare Chat Service a 2 repliche, verificare sticky sessions
2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
3. Monitoring (Prometheus + Grafana) per ogni servizio
---
## 9. Monitoraggio e Logging
```yaml
# Aggiungere al docker-compose.yml
prometheus:
image: prom/prometheus:latest
volumes:
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
restart: unless-stopped
grafana:
image: grafana/grafana:latest
ports:
- "3000:3000"
volumes:
- grafana_data:/var/lib/grafana
restart: unless-stopped
loki:
image: grafana/loki:latest
restart: unless-stopped
```
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
---
## 10. Sizing VPS Minimo Consigliato (MVP)
| Componente | CPU | RAM | Note |
|---|---|---|---|
| Traefik | 0.25 | 128MB | |
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
| Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
| Billing Service | 0.25 | 128MB | |
| PostgreSQL | 1.0 | 1GB | |
| Redis | 0.25 | 256MB | |
| Qdrant | 0.5 | 512MB | |
| **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
---
## Riepilogo Architettura MVP
| Servizio | Repliche | Proprietario di |
|---|---|---|
| **Traefik** | 1 | Routing, TLS, sticky sessions |
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
| **Chat Service** | 2N | WebSocket, home/floating chat, streaming |
| **Agent Service** | 2N | Batch processing, directory scan, agent setup |
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
| Decisione | Scelta | Motivazione |
|---|---|---|
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
| Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
| WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
| Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
| Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
| Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
| TLS | Cloudflare Origin Certificate | Zero maintenance |
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
| Storage / Plugin | Post-MVP | Non critici per il lancio |

View File

@@ -4,6 +4,8 @@ gunicorn>=22.0.0
langchain>=0.3.0 langchain>=0.3.0
langchain-openai>=0.3.0 langchain-openai>=0.3.0
langchain-litellm>=0.1.0 langchain-litellm>=0.1.0
langgraph>=0.3.0
deepagents>=0.4.10
litellm>=1.50.0 litellm>=1.50.0
pydantic>=2.10.0 pydantic>=2.10.0
pydantic-settings>=2.7.0 pydantic-settings>=2.7.0
@@ -32,5 +34,4 @@ google-auth-oauthlib>=1.2.0
google-auth-httplib2>=0.2.0 google-auth-httplib2>=0.2.0
msal>=1.28.0 msal>=1.28.0
cryptography>=42.0.0 cryptography>=42.0.0
langfuse>=2.0.0
ruff>=0.8.0 ruff>=0.8.0

View File

@@ -10,13 +10,13 @@ Coverage:
- run_local_agent — file-read timeout path - run_local_agent — file-read timeout path
- run_local_agent — LLM extraction error path - run_local_agent — LLM extraction error path
- run_cloud_agent — stub returns error immediately - run_cloud_agent — stub returns error immediately
- trigger_pending_runs — skipped when config is client-owned - trigger_pending_runs — overdue local + cloud dispatched
- trigger_pending_runs — non-overdue skipped - trigger_pending_runs — non-overdue skipped
- trigger_pending_runs — device_id filter for local agents - trigger_pending_runs — device_id filter for local agents
Integration: Integration:
- POST /agents/can-create — billing eligibility check - POST /agents/{id}/run — 404 on unknown agent
- POST /agents/trigger — creates run log + dispatches background task - POST /agents/{id}/run — creates run log + dispatches background task
""" """
from __future__ import annotations from __future__ import annotations
@@ -373,7 +373,7 @@ async def test_run_local_agent_happy_path():
assert kwargs["items_processed"] == 1 assert kwargs["items_processed"] == 1
assert kwargs["items_created"] == 1 assert kwargs["items_created"] == 1
assert kwargs["errors"] == [] assert kwargs["errors"] == []
assert kwargs["update_config_last_run"] is False assert kwargs["update_config_last_run"] is True
# Verify agent_run frame was sent. # Verify agent_run frame was sent.
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"] agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
@@ -690,11 +690,31 @@ async def test_finalize_run_updates_cloud_config_last_run_at():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trigger_pending_runs_no_overdue(): async def test_trigger_pending_runs_no_overdue():
"""Pending-run scan is skipped because agent config is client-owned.""" """If no agents are overdue trigger_pending_runs does nothing."""
from datetime import timedelta
config = _make_local_config()
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
mock_db_result_local = MagicMock()
mock_db_result_local.scalars.return_value.all.return_value = [config]
mock_db_result_cloud = MagicMock()
mock_db_result_cloud.scalars.return_value.all.return_value = []
mgr = _make_manager() mgr = _make_manager()
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run: with patch("app.core.agent_runner.async_session") as mock_session_factory, \
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
mock_ctx.execute = AsyncMock(
side_effect=[mock_db_result_local, mock_db_result_cloud]
)
mock_session_factory.return_value = mock_ctx
await trigger_pending_runs(_FREE_UID, "dev-001", mgr) await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called() mock_run.assert_not_called()
@@ -702,11 +722,31 @@ async def test_trigger_pending_runs_no_overdue():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trigger_pending_runs_device_id_filter(): async def test_trigger_pending_runs_device_id_filter():
"""Device filtering is no longer backend-managed in pending runs.""" """Local agents are only triggered for the matching device_id."""
# The DB query already filters by device_id, so we verify the SELECT
# includes the device_id filter by checking that a config bound to a
# different device is never dispatched.
#
# Since trigger_pending_runs queries with device_id == "dev-001",
# simulate the DB returning an empty list (as it would for a mismatch).
mock_db_result_local = MagicMock()
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
mock_db_result_cloud = MagicMock()
mock_db_result_cloud.scalars.return_value.all.return_value = []
mgr = _make_manager(device_id="dev-001") mgr = _make_manager(device_id="dev-001")
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run: with patch("app.core.agent_runner.async_session") as mock_session_factory, \
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
mock_ctx.execute = AsyncMock(
side_effect=[mock_db_result_local, mock_db_result_cloud]
)
mock_session_factory.return_value = mock_ctx
await trigger_pending_runs(_FREE_UID, "dev-001", mgr) await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called() mock_run.assert_not_called()
@@ -714,18 +754,56 @@ async def test_trigger_pending_runs_device_id_filter():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trigger_pending_runs_dispatches_overdue(): async def test_trigger_pending_runs_dispatches_overdue():
"""No pending runs are dispatched by backend after config deprecation.""" """Overdue local agent triggers run_local_agent sequentially."""
config = _make_local_config() # last_run_at=None → always overdue
mock_db_result_local = MagicMock()
mock_db_result_local.scalars.return_value.all.return_value = [config]
mock_db_result_cloud = MagicMock()
mock_db_result_cloud.scalars.return_value.all.return_value = []
mgr = _make_manager() mgr = _make_manager()
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run: call_order: list[str] = []
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
call_order.append("run_local")
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
# First call: query configs. Subsequent calls: create run_log.
mock_query_ctx = AsyncMock()
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
mock_query_ctx.execute = AsyncMock(
side_effect=[mock_db_result_local, mock_db_result_cloud]
)
run_log_obj = AgentRunLog(
id=str(uuid.uuid4()),
agent_id=config.id,
agent_type="local",
user_id=_FREE_UID,
status="running",
started_at=datetime.now(timezone.utc),
)
mock_insert_ctx = AsyncMock()
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
mock_insert_ctx.add = MagicMock()
mock_insert_ctx.commit = AsyncMock()
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
await trigger_pending_runs(_FREE_UID, "dev-001", mgr) await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called() assert call_order == ["run_local"]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Integration: POST /agents/can-create and /agents/trigger # Integration: POST /agents/{id}/run
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -742,67 +820,50 @@ def _override_db(db_session):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_can_create_agent_allows_when_under_limit(client): async def test_trigger_run_unknown_agent(client):
"""POST /agents/can-create returns allowed=True when under tier limit.""" """POST /agents/{id}/run returns 404 for unknown agent id."""
resp = client.post( resp = client.post(
"/api/v1/agents/can-create", f"/api/v1/agents/{uuid.uuid4()}/run",
json={"active_agents": 0}, headers=auth_header("power"),
headers=auth_header("free"),
) )
assert resp.status_code == 200 assert resp.status_code == 404
body = resp.json()
assert body["allowed"] is True
assert body["tier"] == "free"
assert body["active_agents"] == 0
assert body["limit"] == 2
@pytest.mark.asyncio
async def test_can_create_agent_denies_when_at_limit(client):
"""POST /agents/can-create returns allowed=False at free-tier limit."""
resp = client.post(
"/api/v1/agents/can-create",
json={"active_agents": 2},
headers=auth_header("free"),
)
assert resp.status_code == 200
body = resp.json()
assert body["allowed"] is False
assert body["limit"] == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trigger_run_local_agent_creates_run_log(client, db_session): async def test_trigger_run_local_agent_creates_run_log(client, db_session):
"""POST /agents/trigger creates a local run log and dispatches background task.""" """POST /agents/{id}/run creates a run log and dispatches a background task."""
dispatched: list[tuple[str, str]] = [] # Create the local agent config in the DB.
config = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=TEST_USER_IDS["power"],
device_id="dev-001",
name="My Agent",
directory_paths=["/home/user/docs"],
data_types=["tasks"],
prompt_template="Extract tasks.",
file_extensions=[".txt"],
schedule_cron="0 */6 * * *",
enabled=True,
)
db_session.add(config)
await db_session.commit()
dispatched: list = []
async def _fake_run(user_id, cfg, run_log, device_mgr): async def _fake_run(user_id, cfg, run_log, device_mgr):
dispatched.append((user_id, cfg.id)) dispatched.append((user_id, cfg.id))
def _fake_create_task(coro):
coro.close()
return MagicMock()
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \ with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
patch("asyncio.create_task") as mock_create_task: patch("asyncio.create_task") as mock_create_task:
mock_create_task.side_effect = _fake_create_task
resp = client.post( resp = client.post(
"/api/v1/agents/trigger", f"/api/v1/agents/{config.id}/run",
json={
"directory": "/home/user/docs",
"what_to_extract": ["task", "note"],
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
"batch_interval": "0 */6 * * *",
"custom_agent_prompt": "Extract tasks and notes.",
"active_agents": 0,
},
headers=auth_header("power"), headers=auth_header("power"),
) )
assert resp.status_code == 202 assert resp.status_code == 202
data = resp.json() data = resp.json()
assert isinstance(data["agent_id"], str) assert data["agent_id"] == config.id
assert data["agent_id"]
assert data["status"] == "running" assert data["status"] == "running"
assert data["agent_type"] == "local" assert data["agent_type"] == "local"

243
tests/test_backup.py Normal file
View File

@@ -0,0 +1,243 @@
"""Tests for backup routes: upload, download, history, delete.
Exercises the backup lifecycle through the FastAPI TestClient against the
in-memory SQLite test database and moto-mocked S3 bucket.
"""
from __future__ import annotations
import hashlib
from tests.conftest import auth_header, TEST_USER_IDS
# ── Helpers ───────────────────────────────────────────────────────────
_BLOB = b"encrypted-backup-blob-opaque-bytes"
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
_VERSION = 1
_TIMESTAMP = 1700000000000 # arbitrary ms timestamp
def _backup_headers(tier: str = "power", **overrides) -> dict[str, str]:
"""Return auth + backup metadata headers."""
headers = auth_header(tier)
headers["X-Backup-Version"] = str(overrides.get("version", _VERSION))
headers["X-Backup-Timestamp"] = str(overrides.get("timestamp", _TIMESTAMP))
headers["X-Backup-Checksum"] = overrides.get("checksum", _CHECKSUM)
headers["Content-Type"] = "application/octet-stream"
return headers
def _upload(client, tier="power", **overrides) -> "Response": # noqa: F821
"""Upload a backup blob and return the response."""
return client.put(
"/api/v1/backup",
content=overrides.pop("blob", _BLOB),
headers=_backup_headers(tier, **overrides),
)
# ── TestUploadBackup ──────────────────────────────────────────────────
class TestUploadBackup:
"""PUT /api/v1/backup"""
def test_upload_success(self, client, s3_bucket) -> None:
resp = _upload(client, tier="power")
assert resp.status_code == 200
assert resp.json() == {"ok": True}
def test_upload_creates_history_entry(self, client, s3_bucket) -> None:
_upload(client, tier="power")
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert len(history) == 1
assert history[0]["version"] == _VERSION
assert history[0]["timestamp"] == _TIMESTAMP
assert history[0]["checksum"] == _CHECKSUM
def test_upload_bad_checksum(self, client, s3_bucket) -> None:
resp = _upload(client, tier="power", checksum="0" * 64)
assert resp.status_code == 400
def test_upload_free_tier_blocked(self, client, s3_bucket) -> None:
"""Free tier has backup_gb=0 → should return 402."""
resp = _upload(client, tier="free")
assert resp.status_code == 402
def test_upload_pro_tier_allowed(self, client, s3_bucket) -> None:
"""Pro tier has backup_gb=5 → small blob succeeds."""
resp = _upload(client, tier="pro")
assert resp.status_code == 200
# ── TestDownloadBackup ────────────────────────────────────────────────
class TestDownloadBackup:
"""GET /api/v1/backup"""
def test_download_latest(self, client, s3_bucket) -> None:
_upload(client, tier="power")
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.content == _BLOB
assert resp.headers["X-Checksum"] == _CHECKSUM
assert resp.headers["X-Backup-Version"] == str(_VERSION)
def test_download_no_backup_returns_404(self, client, s3_bucket) -> None:
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 404
def test_download_if_modified_since_returns_304(self, client, s3_bucket) -> None:
"""When If-Modified-Since is after the backup timestamp → 304."""
_upload(client, tier="power", timestamp=1700000000000)
resp = client.get(
"/api/v1/backup",
headers={
**auth_header("power"),
"If-Modified-Since": "Thu, 01 Jan 2099 00:00:00 GMT",
},
)
assert resp.status_code == 304
def test_download_if_modified_since_returns_200(self, client, s3_bucket) -> None:
"""When If-Modified-Since is before the backup timestamp → serve blob."""
_upload(client, tier="power", timestamp=1700000000000)
resp = client.get(
"/api/v1/backup",
headers={
**auth_header("power"),
"If-Modified-Since": "Thu, 01 Jan 2000 00:00:00 GMT",
},
)
assert resp.status_code == 200
assert resp.content == _BLOB
def test_download_multiple_returns_latest(self, client, s3_bucket) -> None:
"""When multiple backups exist, GET returns the one with the highest timestamp."""
_upload(client, tier="power", timestamp=1000)
blob2 = b"second-encrypted-backup"
checksum2 = hashlib.sha256(blob2).hexdigest()
_upload(client, tier="power", timestamp=2000, blob=blob2, checksum=checksum2)
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.content == blob2
# ── TestBackupHistory ─────────────────────────────────────────────────
class TestBackupHistory:
"""GET /api/v1/backup/history"""
def test_history_empty(self, client, s3_bucket) -> None:
resp = client.get("/api/v1/backup/history", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.json() == []
def test_history_returns_entries(self, client, s3_bucket) -> None:
_upload(client, tier="power", timestamp=1000)
_upload(client, tier="power", timestamp=2000)
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert len(history) == 2
# Ordered by timestamp descending
assert history[0]["timestamp"] == 2000
assert history[1]["timestamp"] == 1000
def test_history_isolated_per_user(self, client, s3_bucket) -> None:
"""One user's backups should not appear in another user's history."""
_upload(client, tier="power")
resp = client.get("/api/v1/backup/history", headers=auth_header("team"))
assert resp.json() == []
# ── TestDeleteBackup ──────────────────────────────────────────────────
class TestDeleteBackup:
"""DELETE /api/v1/backup/{backup_id}"""
def _get_backup_id(self, client, tier="power") -> str:
"""Upload a backup and return its DB id from history."""
_upload(client, tier=tier)
client.get(
"/api/v1/backup/history", headers=auth_header(tier)
).json()
# History returns BackupMetadata schema which doesn't have `id`.
# We need to look it up via a different means.
# Since there's only 1 backup, find via history length.
# Actually the schema doesn't return id — let's verify via re-download.
# We'll use a workaround: upload, then list history to confirm it exists,
# then try to delete — but we need the id...
# Let's check if history includes an id field.
# The schema is: version, timestamp, checksum, chunk_count — no id.
# We'll need to query the DB directly or use a known ID.
# For testing, we'll search history then use the DB.
return None # pragma: no cover — overridden below
def test_delete_success(self, client, s3_bucket, db_session) -> None:
_upload(client, tier="power")
# Discover the backup_id via direct DB query
import asyncio
from sqlalchemy import select
from app.models import BackupMetadata
async def _get_id():
result = await db_session.execute(
select(BackupMetadata.id).where(
BackupMetadata.user_id == TEST_USER_IDS["power"]
)
)
return result.scalar_one()
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
resp = client.delete(
f"/api/v1/backup/{backup_id}", headers=auth_header("power")
)
assert resp.status_code == 200
assert resp.json() == {"ok": True}
# History should now be empty
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert history == []
def test_delete_nonexistent(self, client, s3_bucket) -> None:
resp = client.delete(
"/api/v1/backup/no-such-id", headers=auth_header("power")
)
assert resp.status_code == 404
def test_delete_other_users_backup(self, client, s3_bucket, db_session) -> None:
"""Cannot delete another user's backup (ownership check returns 404)."""
_upload(client, tier="power")
import asyncio
from sqlalchemy import select
from app.models import BackupMetadata
async def _get_id():
result = await db_session.execute(
select(BackupMetadata.id).where(
BackupMetadata.user_id == TEST_USER_IDS["power"]
)
)
return result.scalar_one()
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
# team user tries to delete power user's backup → 404
resp = client.delete(
f"/api/v1/backup/{backup_id}", headers=auth_header("team")
)
assert resp.status_code == 404

View File

@@ -1,184 +0,0 @@
"""Unit tests for Step 1 file classification (_classify_file).
These tests call the real LLM so they require OPENAI_API_KEY / LLM env vars.
Run with: pytest tests/test_classify_file.py -v
To run a quick manual check against a real file without the full UI:
python -m tests.test_classify_file <path/to/file.txt> [project_name...]
"""
from __future__ import annotations
import asyncio
import sys
import pytest
from app.core.agent_runner import _classify_file
# ── Fixtures ──────────────────────────────────────────────────────────────
PROJECTS_SAMPLE = [
{
"id": "aaaa-0001-0000-0000-000000000001",
"name": "ARPA Sicilia POC",
"status": "active",
"aiSummary": "Proof of concept for AI features targeting ARPA Sicilia agency.",
},
{
"id": "bbbb-0002-0000-0000-000000000002",
"name": "SNAM AI Meeting Prep",
"status": "active",
"aiSummary": "AI-assisted preparation of meeting materials for SNAM.",
},
{
"id": "cccc-0003-0000-0000-000000000003",
"name": "SFERA+ Wave 2",
"status": "active",
"aiSummary": "Second wave of the SFERA+ whitelist project.",
},
]
ARPA_EMAIL = """\
to: roberto.musso@hpe.com; luca.tondin@hpecds.com
isImportance: normal
hasAttachment: True
---
## Body
Buongiorno,
In riferimento alla riunione di ieri sul POC ARPA Sicilia, vi invio il riassunto
dei deliverable concordati:
- Preparare demo entro il 30 marzo
- Condividere documentazione tecnica con il team ARPA
- Fissare call di follow-up la prossima settimana
Cordiali saluti
Roberto Marchetti
"""
SNAM_EMAIL = """\
to: roberto.musso@hpe.com
isImportance: high
hasAttachment: False
---
## Body
Ciao,
ti invio l'agenda per la riunione SNAM di domani.
Per favore conferma la tua presenza.
"""
UNRELATED_EMAIL = """\
to: roberto.musso@hpe.com
isImportance: normal
---
## Body
Benvenuto nel programma HPE Employee Learning Series.
Completa la formazione richiesta entro la fine del trimestre.
"""
# ── Tests ─────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_classify_arpa_matches_existing():
project_id, domains, new_name = await _classify_file(
file_path="arpa_email.txt",
file_content=ARPA_EMAIL,
projects=PROJECTS_SAMPLE,
config_data_types=["tasks", "notes", "timelines"],
)
assert project_id == "aaaa-0001-0000-0000-000000000001", (
f"Expected ARPA project, got project_id={project_id!r} new_name={new_name!r}"
)
assert new_name is None
@pytest.mark.asyncio
async def test_classify_snam_matches_existing():
project_id, domains, new_name = await _classify_file(
file_path="snam_email.txt",
file_content=SNAM_EMAIL,
projects=PROJECTS_SAMPLE,
config_data_types=["tasks", "notes"],
)
assert project_id == "bbbb-0002-0000-0000-000000000002", (
f"Expected SNAM project, got project_id={project_id!r} new_name={new_name!r}"
)
@pytest.mark.asyncio
async def test_classify_unrelated_returns_new():
project_id, domains, new_name = await _classify_file(
file_path="learning_email.txt",
file_content=UNRELATED_EMAIL,
projects=PROJECTS_SAMPLE,
config_data_types=["tasks", "notes"],
)
assert project_id == "new"
assert new_name is not None # LLM should suggest a name
@pytest.mark.asyncio
async def test_classify_empty_file_returns_new():
project_id, domains, new_name = await _classify_file(
file_path="empty.txt",
file_content=" ",
projects=PROJECTS_SAMPLE,
config_data_types=["tasks"],
)
assert project_id == "new"
@pytest.mark.asyncio
async def test_classify_no_projects_returns_new():
project_id, domains, new_name = await _classify_file(
file_path="arpa_email.txt",
file_content=ARPA_EMAIL,
projects=[],
config_data_types=["tasks", "notes"],
)
assert project_id == "new"
assert new_name is not None
# ── CLI quick-test runner ─────────────────────────────────────────────────
async def _cli_test(file_path: str, project_names: list[str]) -> None:
"""Run Step 1 classification against a real file from the CLI."""
import json
from pathlib import Path
content = Path(file_path).read_text(encoding="utf-8", errors="replace")
projects = [
{"id": f"test-id-{i:04d}", "name": name, "status": "active", "aiSummary": ""}
for i, name in enumerate(project_names)
]
print(f"\nClassifying: {file_path}")
print(f"Projects in context: {[p['name'] for p in projects]}\n")
project_id, domains, new_name = await _classify_file(
file_path=file_path,
file_content=content,
projects=projects,
config_data_types=["tasks", "notes", "timelines"],
)
result = {
"project_id": project_id,
"matched_name": next((p["name"] for p in projects if p["id"] == project_id), None),
"new_project_name": new_name,
"domains": domains,
}
print(json.dumps(result, indent=2, ensure_ascii=False))
if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: python -m tests.test_classify_file <file_path> [project_name ...]")
sys.exit(1)
asyncio.run(_cli_test(sys.argv[1], sys.argv[2:]))

View File

@@ -1,288 +0,0 @@
"""Unit tests for single-agent deep_agent flows with mocked tool results."""
from __future__ import annotations
from datetime import date, timedelta
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from langchain_core.messages import AIMessage, ToolMessage
from app.core.deep_agent import (
_infer_floating_domain,
_normalize_tagged_list_lines,
run_floating,
run_floating_stream,
run_home,
)
class _FakeTool:
name = "list_tasks"
async def ainvoke(self, args):
return {"rows": [{"id": "task-1", "title": "Mock Task"}], "echo": args}
class _FakeLLM:
def __init__(self) -> None:
self.agent_calls = 0
def bind_tools(self, _tools):
return self
async def ainvoke(self, messages):
system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
if "strict domain classifier" in system_prompt:
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
self.agent_calls += 1
if self.agent_calls == 1:
return AIMessage(
content="",
tool_calls=[
{
"id": "call-1",
"name": "list_tasks",
"args": {"project_id": "proj-1"},
}
],
)
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
assert tool_messages, "Expected at least one tool message"
return AIMessage(content=f"Final answer from mocked tool: {tool_messages[-1].content}")
async def astream(self, _messages):
yield SimpleNamespace(content="stream-")
yield SimpleNamespace(content="ok")
@pytest.mark.asyncio
async def test_run_home_uses_mocked_tool_result():
fake_llm = _FakeLLM()
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
out = await run_home("user-1", "list my tasks", {})
assert "Final answer from mocked tool" in out
assert "Mock Task" in out
@pytest.mark.asyncio
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
fake_llm = _FakeLLM()
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"show me timeline updates",
{"scope": {"type": "timeline", "id": "tl-1"}},
):
events.append(event)
assert events[0] == (
"floating_domain",
{"type": "timeline", "id": "tl-1", "section": None},
)
assert ("token", "stream-") in events
assert ("token", "ok") in events
@pytest.mark.asyncio
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
class _ClassifierOnlyLLM:
async def ainvoke(self, _messages):
return AIMessage(
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
)
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
domain = await _infer_floating_domain(
"Quali sono i miei task per il progetto X",
{
"scope": {"type": "timeline"},
"resolved_project_id": "213213-312321-312312-421321",
},
)
assert domain == {
"type": "project",
"id": "213213-312321-312312-421321",
"section": "task",
}
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
raw = (
"Certo!\n\n"
"1. **Task A** — priorita high <task>[task-1]</task>\n"
"2. **Task B** — priorita medium <task>[task-2]</task>\n"
)
out = _normalize_tagged_list_lines(raw, "quali sono le prossime attivita?")
assert "<task>[task-1]</task>" in out
assert "<task>[task-2]</task>" in out
assert "Task A" not in out
assert "Task B" not in out
def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_month_future_only():
today = date.today()
tomorrow = today + timedelta(days=1)
yesterday = today - timedelta(days=1)
next_month = (today.replace(day=28) + timedelta(days=5)).replace(day=1)
raw = "\n".join(
[
f"- Milestone old — {yesterday.strftime('%d/%m/%Y')} <timeline>[tl-old]</timeline>",
f"- Milestone next — {tomorrow.strftime('%d/%m/%Y')} <timeline>[tl-next]</timeline>",
f"- Milestone future — {next_month.strftime('%d/%m/%Y')} <timeline>[tl-future]</timeline>",
]
)
out = _normalize_tagged_list_lines(raw, "invece i miei eventi prossimi?")
assert "<timeline>[tl-next]</timeline>" in out
assert "<timeline>[tl-old]</timeline>" not in out
assert "<timeline>[tl-future]</timeline>" not in out
@pytest.mark.asyncio
async def test_run_floating_strips_xml_like_tags_from_final_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return (
"Hai 1 task:\\n"
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
)
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert "<task>" not in text
assert "</task>" not in text
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
@pytest.mark.asyncio
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "Hai 1 task:\\n"
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
token_events = [str(data) for event_type, data in events if event_type == "token"]
combined = "".join(token_events)
assert "<task>" not in combined
assert "</task>" not in combined
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
@pytest.mark.asyncio
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
class _NoChunkLLM:
def __init__(self) -> None:
self.calls = 0
def bind_tools(self, _tools):
return self
async def ainvoke(self, _messages):
self.calls += 1
if self.calls == 1:
return AIMessage(
content="",
tool_calls=[
{
"id": "call-1",
"name": "list_tasks",
"args": {},
}
],
)
return AIMessage(content="No notes found.")
async def astream(self, _messages):
if False:
yield None
with patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM()), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"quali sono le note?",
{"scope": {"type": "note"}},
):
events.append(event)
assert events[0][0] == "floating_domain"
assert ("token", "No notes found.") in events
@pytest.mark.asyncio
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert text == "No results found."
@pytest.mark.asyncio
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
assert ("token", "No results found.") in events

View File

@@ -110,32 +110,6 @@ async def test_enrich_context_returns_episodic_memory(db_session, user_with_key)
assert any("Q1 tasks" in s for s in ctx["episodic_memory"]) assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
@pytest.mark.asyncio
async def test_enrich_context_filters_episodic_by_session_id(db_session, user_with_key):
target_session = str(uuid.uuid4())
other_session = str(uuid.uuid4())
db_session.add(MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=USER_ID,
summary_encrypted=_enc("Target session memory"),
session_id=target_session,
))
db_session.add(MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=USER_ID,
summary_encrypted=_enc("Other session memory"),
session_id=other_session,
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "any message", session_id=target_session)
episodic = ctx.get("episodic_memory", [])
assert any("Target session" in s for s in episodic)
assert not any("Other session" in s for s in episodic)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key): async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
# Add one pattern above threshold and one below # Add one pattern above threshold and one below
@@ -255,40 +229,6 @@ async def test_update_core_upsert(db_session, user_with_key):
assert _dec(rows[0].value_encrypted) == "fr" assert _dec(rows[0].value_encrypted) == "fr"
@pytest.mark.asyncio
async def test_core_block_edit_ops(db_session, user_with_key):
middleware = MemoryMiddleware(db_session)
await middleware.update_core(USER_ID, "human", "Name: Roberto")
await middleware.append_core(USER_ID, "human", "Timezone: Europe/Rome")
replaced = await middleware.replace_core(USER_ID, "human", "Roberto", "Robert")
blocks = await middleware.list_core_blocks(USER_ID)
human = next(b for b in blocks if b["label"] == "human")
assert replaced is True
assert "Name: Robert" in human["value"]
assert "Timezone: Europe/Rome" in human["value"]
deleted = await middleware.delete_core(USER_ID, "human")
assert deleted is True
assert await middleware.get_core_block(USER_ID, "human") is None
@pytest.mark.asyncio
async def test_archival_and_recall_search_helpers(db_session, user_with_key):
middleware = MemoryMiddleware(db_session)
await middleware.insert_archival(USER_ID, "Project whitelist has release risk", source="assistant")
await middleware.store_episode(USER_ID, str(uuid.uuid4()), "How is whitelist?", "Whitelist is delayed")
arch = await middleware.search_archival(USER_ID, "whitelist", top_k=3)
rec = await middleware.search_recall(USER_ID, "delayed", top_k=3)
assert any("whitelist" in item.lower() for item in arch)
assert any("delayed" in item.lower() for item in rec)
# ── End-to-end WS: memory middleware is called during home_request ──────────── # ── End-to-end WS: memory middleware is called during home_request ────────────
def test_home_request_calls_memory_middleware(client): def test_home_request_calls_memory_middleware(client):
@@ -300,20 +240,21 @@ def test_home_request_calls_memory_middleware(client):
def __init__(self, db): def __init__(self, db):
pass pass
async def enrich_context(self, user_id, message, **kwargs): async def enrich_context(self, user_id, message):
enrich_calls.append((user_id, message)) enrich_calls.append((user_id, message))
return {"core_memory": {"tz": "UTC"}} return {"core_memory": {"tz": "UTC"}}
async def store_episode(self, user_id, session_id, message, response, **kwargs): async def store_episode(self, user_id, session_id, message, response):
store_calls.append((user_id, session_id, message, response)) store_calls.append((user_id, session_id, message, response))
token = make_jwt("power", user_id=USER_ID) token = make_jwt("power", user_id=USER_ID)
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
async def _mock_stream(user_id, message, context): async def _mock_stream(user_id, message, context, db_session_factory=None):
# Verify memory context was injected # Verify memory context was injected
assert context.get("core_memory") == {"tz": "UTC"} assert context.get("core_memory") == {"tz": "UTC"}
yield "token", "Done" yield ("token", "Done")
yield ("mutations", [])
with ( with (
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware), patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),

View File

@@ -1,82 +1,214 @@
"""Tests for app.core.output_formatter.StreamFormatter.""" """Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
from __future__ import annotations from __future__ import annotations
import pytest import pytest
from app.core.output_formatter import StreamFormatter from app.core.output_formatter import HomeFormatter, FloatingFormatter
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText from app.schemas import (
WsFloatingDomain,
WsStreamEnd,
WsStreamStart,
WsStreamText,
)
# ── helpers ───────────────────────────────────────────────────────────────────
async def _stream(*events: tuple[str, object]): async def _stream(*events: tuple[str, object]):
"""Async generator that yields (event_type, data) tuples."""
for event in events: for event in events:
yield event yield event
async def _collect(formatter: StreamFormatter, event_stream): async def collect(formatter, event_stream):
frames = [] frames = []
async for frame in formatter.format(event_stream): async for frame in formatter.format(event_stream):
frames.append(frame) frames.append(frame)
return frames return frames
# ── HomeFormatter ─────────────────────────────────────────────────────────────
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stream_formatter_text_stream() -> None: async def test_home_formatter_plain_text():
formatter = StreamFormatter(request_id="req-1") req_id = "req-1"
frames = await _collect( events = [
formatter, ("token", "Hello world"),
_stream(("token", "Hello"), ("token", " world")), ("mutations", []),
) ]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsStreamStart) assert isinstance(frames[0], WsStreamStart)
assert isinstance(frames[1], WsStreamText) assert frames[0].request_id == req_id
assert frames[1].chunk == "Hello" text_frames = [f for f in frames if isinstance(f, WsStreamText)]
assert isinstance(frames[2], WsStreamText) assert any("Hello world" in f.chunk for f in text_frames)
assert frames[2].chunk == " world"
assert isinstance(frames[-1], WsStreamEnd) assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stream_formatter_floating_domain_first() -> None: async def test_home_formatter_entity_tags_passed_through():
formatter = StreamFormatter(request_id="req-2") """Entity tags are streamed as-is — the frontend parses them."""
frames = await _collect( req_id = "req-2"
formatter, events = [
_stream( ("token", "Here is your project:\n<project>[abc-123]</project>\nAll good."),
( ("mutations", []),
"floating_domain", ]
{"type": "node", "id": "n-1", "section": None}, formatter = HomeFormatter(request_id=req_id)
), frames = await collect(formatter, _stream(*events))
("token", "Summary"),
), text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
) assert "<project>[abc-123]</project>" in text
assert "Here is your project:" in text
assert "All good." in text
@pytest.mark.asyncio
async def test_home_formatter_multiple_tags_passed_through():
req_id = "req-3"
events = [
("token", "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"),
("mutations", []),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
assert "<project>[p1]</project>" in text
assert "<task>[t1,t2]</task>" in text
@pytest.mark.asyncio
async def test_home_formatter_tool_end_ignored():
"""tool_end events are silently ignored by HomeFormatter."""
req_id = "req-4"
events = [
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
("token", "No tags here."),
("mutations", []),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
assert text == "No tags here."
@pytest.mark.asyncio
async def test_home_formatter_mutations_in_stream_end():
req_id = "req-5"
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
events = [
("token", "Done"),
("mutations", muts),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
end_frame = frames[-1]
assert isinstance(end_frame, WsStreamEnd)
assert len(end_frame.mutations) == 1
assert end_frame.mutations[0]["action"] == "insert"
@pytest.mark.asyncio
async def test_home_formatter_frame_order():
"""stream_start is first, stream_end is last."""
req_id = "req-6"
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
assert isinstance(frames[0], WsStreamStart)
assert isinstance(frames[-1], WsStreamEnd)
# ── FloatingFormatter ─────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_floating_formatter_domain_from_tool_end():
req_id = "pop-1"
formatter = FloatingFormatter(request_id=req_id)
events = [
("tool_end", {"name": "task_agent", "result": "ok"}),
("token", "Hello"),
("mutations", []),
]
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsFloatingDomain) assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain.type == "node" assert frames[0].domain == "tasks"
assert frames[0].domain.id == "n-1" assert frames[0].request_id == req_id
assert isinstance(frames[1], WsStreamStart)
assert isinstance(frames[2], WsStreamText)
assert frames[2].chunk == "Summary" @pytest.mark.asyncio
async def test_floating_formatter_text_only():
req_id = "pop-2"
formatter = FloatingFormatter(request_id=req_id)
events = [
("tool_end", {"name": "timeline_agent", "result": "done"}),
("token", "Summary"),
("mutations", []),
]
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "timelines"
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
assert len(text_frames) == 1
assert text_frames[0].chunk == "Summary"
@pytest.mark.asyncio
async def test_floating_formatter_no_entity_tags():
"""FloatingFormatter never emits entity tag blocks."""
req_id = "pop-3"
formatter = FloatingFormatter(request_id=req_id)
events = [
("tool_end", {"name": "note_agent", "result": "data"}),
("token", "some text"),
("mutations", []),
]
frames = await collect(formatter, _stream(*events))
# Only expected frame types
for f in frames:
assert isinstance(f, (WsFloatingDomain, WsStreamStart, WsStreamText, WsStreamEnd))
@pytest.mark.asyncio
async def test_floating_formatter_end_frame():
req_id = "pop-4"
formatter = FloatingFormatter(request_id=req_id)
events = [
("tool_end", {"name": "project_agent", "result": "ok"}),
("token", "Done"),
("mutations", []),
]
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[-1], WsStreamEnd) assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stream_formatter_ignores_unknown_events() -> None: async def test_floating_formatter_default_domain_on_early_token():
formatter = StreamFormatter(request_id="req-3") """When the first event is a token (no tool_end yet), default to 'tasks'."""
frames = await _collect( req_id = "pop-5"
formatter, formatter = FloatingFormatter(request_id=req_id)
_stream(("tool_end", {"name": "x"}), ("token", "ok")), events = [("token", "hi"), ("mutations", [])]
) frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsFloatingDomain)
text_frames = [f for f in frames if isinstance(f, WsStreamText)] assert frames[0].domain == "tasks"
assert len(text_frames) == 1
assert text_frames[0].chunk == "ok"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stream_formatter_empty_stream_still_brackets() -> None: async def test_floating_formatter_mutations_in_stream_end():
formatter = StreamFormatter(request_id="req-4") req_id = "pop-6"
frames = await _collect(formatter, _stream()) muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
events = [
("token", "Updated"),
("mutations", muts),
]
formatter = FloatingFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
assert len(frames) == 2 end_frame = frames[-1]
assert isinstance(frames[0], WsStreamStart) assert isinstance(end_frame, WsStreamEnd)
assert isinstance(frames[1], WsStreamEnd) assert len(end_frame.mutations) == 1

400
tests/test_plugins.py Normal file
View File

@@ -0,0 +1,400 @@
"""Tests for Step 10+12: Plugin Marketplace (DB-backed).
Covers:
- PluginRegistry: catalog management, filtering, sorting, install counts (PostgreSQL)
- ReviewQueue: pending queue, review decisions, manifest security checklist
- RevenueShare: install event recording, earnings aggregation (PostgreSQL)
- Route integration: tier gate, list/get/install/uninstall via TestClient
"""
from __future__ import annotations
import uuid
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.marketplace.plugin_registry import PluginRegistry
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
from app.marketplace.revenue_share import RevenueShare
from app.models import Plugin, PluginReview as PluginReviewModel, RevenueEvent
from app.schemas import PluginManifest
from tests.conftest import TEST_USER_IDS, auth_header
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _fresh_manifest(
plugin_id: str | None = None,
category: str = "productivity",
price_cents: int = 0,
permissions: list[str] | None = None,
) -> PluginManifest:
pid = plugin_id or f"plugin-{uuid.uuid4().hex[:8]}"
return PluginManifest(
id=pid,
name=f"Plugin {pid}",
description=f"Description for {pid}",
version="1.0.0",
author="test-author",
permissions=permissions or ["read:tasks"],
category=category,
price_cents=price_cents,
)
# ---------------------------------------------------------------------------
# PluginRegistry (DB-backed)
# ---------------------------------------------------------------------------
class TestPluginRegistry:
"""Each test uses the conftest db_session fixture with a fresh in-memory DB."""
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.mark.asyncio
async def test_seed_plugins_are_listed(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session)
assert result.total == 3
assert all(p.id.startswith("plugin-") for p in result.plugins)
@pytest.mark.asyncio
async def test_list_approved_only(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "plugins/key.zip")
result = await reg.list_plugins(db_session)
ids = [p.id for p in result.plugins]
assert manifest.id not in ids # still pending
@pytest.mark.asyncio
async def test_list_filter_by_category(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session, category="communication")
assert result.total == 1
assert result.plugins[0].id == "plugin-slack-notify"
@pytest.mark.asyncio
async def test_list_filter_by_query(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session, query="time tracker")
assert result.total == 1
assert result.plugins[0].id == "plugin-time-tracker"
@pytest.mark.asyncio
async def test_list_sort_by_installs(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-slack-notify")
await reg.record_install(db_session, "plugin-slack-notify")
result = await reg.list_plugins(db_session, sort="installs")
assert result.plugins[0].id == "plugin-slack-notify"
@pytest.mark.asyncio
async def test_get_plugin_found(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["manifest"].id == "plugin-github-sync"
assert "install_count" in entry
@pytest.mark.asyncio
async def test_get_plugin_not_found(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
entry = await reg.get_plugin(db_session, "no-such-plugin")
assert entry is None
@pytest.mark.asyncio
async def test_submit_sets_pending(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
plugin_id = await reg.submit_plugin(db_session, manifest, "key.zip")
assert plugin_id == manifest.id
result = await db_session.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one()
assert row.status == "pending_review"
@pytest.mark.asyncio
async def test_approve_makes_visible(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await reg.approve_plugin(db_session, manifest.id)
result = await reg.list_plugins(db_session)
assert manifest.id in [p.id for p in result.plugins]
@pytest.mark.asyncio
async def test_reject_stores_reason(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await reg.reject_plugin(db_session, manifest.id, reason="Unsafe permissions")
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "rejected"
assert row.rejection_reason == "Unsafe permissions"
listed = await reg.list_plugins(db_session)
assert manifest.id not in [p.id for p in listed.plugins]
@pytest.mark.asyncio
async def test_approve_unknown_raises_key_error(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
with pytest.raises(KeyError):
await reg.approve_plugin(db_session, "ghost-plugin")
@pytest.mark.asyncio
async def test_record_install_increments_count(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_record_uninstall_decrements_count(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-github-sync")
await reg.record_install(db_session, "plugin-github-sync")
await reg.record_uninstall(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_record_uninstall_floors_at_zero(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_uninstall(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 0
# ---------------------------------------------------------------------------
# ReviewQueue (DB-backed)
# ---------------------------------------------------------------------------
class TestReviewQueue:
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.fixture
def queue(self) -> ReviewQueue:
return ReviewQueue()
@pytest.mark.asyncio
async def test_get_pending_returns_submitted_plugins(
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
pending = await queue.get_pending(db_session)
assert any(p["plugin_id"] == manifest.id for p in pending)
@pytest.mark.asyncio
async def test_submit_review_approved(
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await queue.submit_review(db_session, manifest.id, TEST_USER_IDS["power"], "approved", "Looks good")
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "approved"
# Check review row was persisted
review_result = await db_session.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == manifest.id)
)
review = review_result.scalar_one()
assert review.decision == "approved"
@pytest.mark.asyncio
async def test_submit_review_rejected(
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await queue.submit_review(
db_session, manifest.id, TEST_USER_IDS["power"], "rejected", "Bad permissions"
)
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "rejected"
def test_validate_manifest_ok(self) -> None:
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
validate_manifest(manifest) # should not raise
def test_validate_manifest_unknown_permission(self) -> None:
manifest = _fresh_manifest(permissions=["read:tasks", "read:secrets"])
with pytest.raises(ValueError, match="Unknown permission"):
validate_manifest(manifest)
def test_validate_manifest_invalid_id_format(self) -> None:
manifest = _fresh_manifest(plugin_id="Plugin_ID_Invalid")
with pytest.raises(ValueError, match="Invalid plugin id format"):
validate_manifest(manifest)
def test_validate_manifest_id_with_uppercase(self) -> None:
manifest = _fresh_manifest(plugin_id="UpperCase")
with pytest.raises(ValueError, match="Invalid plugin id format"):
validate_manifest(manifest)
# ---------------------------------------------------------------------------
# RevenueShare (DB-backed)
# ---------------------------------------------------------------------------
class TestRevenueShare:
@pytest.fixture
def rs(self) -> RevenueShare:
return RevenueShare()
@pytest.mark.asyncio
async def test_record_install_free_plugin(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
result = await db_session.execute(
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-github-sync")
)
event = result.scalar_one()
assert event.developer_share_cents == 0
@pytest.mark.asyncio
async def test_record_install_paid_plugin_no_stripe(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install(
db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499
)
result = await db_session.execute(
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-slack-notify")
)
event = result.scalar_one()
assert event.amount_cents == 499
assert event.developer_share_cents == int(499 * 0.70)
@pytest.mark.asyncio
async def test_record_install_increments_registry_count(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
reg = PluginRegistry()
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_get_earnings_empty(
self, rs: RevenueShare, db_session: AsyncSession
) -> None:
result = await rs.get_earnings(db_session, "unknown-dev")
assert result["total_installs"] == 0
assert result["total_revenue_cents"] == 0
assert result["developer_share_cents"] == 0
@pytest.mark.asyncio
async def test_get_earnings_aggregates(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["power"], amount_cents=499)
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499)
result = await rs.get_earnings(db_session, "Adiuva")
assert result["total_installs"] == 2
assert result["total_revenue_cents"] == 998
assert result["developer_share_cents"] == int(499 * 0.70) * 2
# ---------------------------------------------------------------------------
# Route integration tests
# ---------------------------------------------------------------------------
class TestPluginRoutes:
def test_list_plugins_requires_power_tier(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("free"))
assert resp.status_code == 403
def test_list_plugins_pro_tier_blocked(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("pro"))
assert resp.status_code == 403
def test_list_plugins_power_tier_ok(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("power"))
assert resp.status_code == 200
data = resp.json()
assert "plugins" in data
assert data["total"] == 3
def test_list_plugins_team_tier_ok(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("team"))
assert resp.status_code == 200
def test_get_plugin_found(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=auth_header())
assert resp.status_code == 200
data = resp.json()
assert data["plugin"]["id"] == "plugin-github-sync"
assert "install_count" in data
def test_get_plugin_not_found(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins/no-such-plugin", headers=auth_header())
assert resp.status_code == 404
def test_install_plugin_free(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=auth_header(),
)
assert resp.status_code == 200
data = resp.json()
assert data["ok"] is True
assert "download_url" in data
def test_install_plugin_not_found(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/ghost/install",
json={"plugin_id": "ghost"},
headers=auth_header(),
)
assert resp.status_code == 404
def test_uninstall_plugin_ok(self, client, seed_plugins) -> None:
resp = client.delete(
"/api/v1/plugins/plugin-github-sync/install",
headers=auth_header(),
)
assert resp.status_code == 200
assert resp.json()["ok"] is True
def test_install_requires_power_tier(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=auth_header("free"),
)
assert resp.status_code == 403

View File

@@ -4,7 +4,6 @@ import pytest
from pydantic import ValidationError from pydantic import ValidationError
from app.schemas import ( from app.schemas import (
WsDomain,
WsFrameType, WsFrameType,
WsHomeRequest, WsHomeRequest,
WsFloatingDomain, WsFloatingDomain,
@@ -179,15 +178,23 @@ def test_stream_text_deserializes():
def test_stream_end_defaults(): def test_stream_end_defaults():
frame = WsStreamEnd(request_id="r1") frame = WsStreamEnd(request_id="r1")
assert frame.type == WsFrameType.stream_end assert frame.type == WsFrameType.stream_end
assert frame.mutations == []
def test_stream_end_with_mutations():
mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}]
frame = WsStreamEnd(request_id="r1", mutations=mutations)
assert len(frame.mutations) == 1
assert frame.mutations[0]["action"] == "create"
def test_stream_end_serializes(): def test_stream_end_serializes():
data = WsStreamEnd(request_id="r2").model_dump() data = WsStreamEnd(request_id="r2").model_dump()
assert data == {"type": "stream_end", "request_id": "r2"} assert data == {"type": "stream_end", "request_id": "r2", "mutations": []}
def test_stream_end_deserializes(): def test_stream_end_deserializes():
raw = {"type": "stream_end", "request_id": "r3"} raw = {"type": "stream_end", "request_id": "r3", "mutations": []}
frame = WsStreamEnd.model_validate(raw) frame = WsStreamEnd.model_validate(raw)
assert frame.request_id == "r3" assert frame.request_id == "r3"
@@ -196,47 +203,28 @@ def test_stream_end_deserializes():
def test_floating_domain_tasks(): def test_floating_domain_tasks():
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task")) frame = WsFloatingDomain(request_id="r1", domain="tasks")
assert frame.type == WsFrameType.floating_domain assert frame.type == WsFrameType.floating_domain
assert frame.domain.type == "task" assert frame.domain == "tasks"
def test_floating_domain_valid_domains(): @pytest.mark.parametrize("domain", ["tasks", "timelines", "notes", "projects"])
frame = WsFloatingDomain( def test_floating_domain_valid_domains(domain: str):
request_id="r1", frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type]
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"), assert frame.domain == domain
)
assert frame.domain.type == "project"
assert frame.domain.id == "213213-312321-312312-421321"
assert frame.domain.section == "task"
def test_floating_domain_object_valid(): def test_floating_domain_invalid():
frame = WsFloatingDomain( with pytest.raises(ValidationError):
request_id="r1", WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type]
domain=WsDomain(type="project", id="p1", section="task"),
)
assert frame.domain.type == "project"
def test_floating_domain_serializes(): def test_floating_domain_serializes():
d = WsFloatingDomain( d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
request_id="r1", assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}
domain=WsDomain(type="timeline"),
).model_dump()
assert d == {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "timeline", "id": None, "section": None},
}
def test_floating_domain_deserializes(): def test_floating_domain_deserializes():
raw = { raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "node", "id": "n-1", "section": None},
}
frame = WsFloatingDomain.model_validate(raw) frame = WsFloatingDomain.model_validate(raw)
assert frame.domain.type == "node" assert frame.domain == "projects"
assert frame.domain.id == "n-1"

562
tests/test_storage.py Normal file
View File

@@ -0,0 +1,562 @@
"""Tests for the storage layer: encryption, BlobStore, VectorStore, and storage routes."""
from __future__ import annotations
import base64
import hashlib
from unittest.mock import MagicMock, patch
import boto3
import pytest
from botocore.exceptions import ClientError
from app.storage.encryption import reject_if_tampered, verify_checksum
from app.storage.blob_store import BlobStore
from app.storage.vector_store import VectorStore, _blob_to_vector
from app.schemas import VectorItem, VectorSearchResult
from tests.conftest import auth_header, S3_TEST_BUCKET
# ── Helpers ───────────────────────────────────────────────────────────
_BLOB = b"encrypted-payload-opaque-to-server"
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
_BUCKET = S3_TEST_BUCKET
_REGION = "us-east-1"
def _pinecone_mock():
"""Return a mock Pinecone index with realistic return shapes."""
mock_index = MagicMock()
mock_index.query.return_value = {
"matches": [
{
"id": "v1",
"score": 0.95,
"metadata": {
"blob": base64.b64encode(b"result-blob").decode(),
"checksum": hashlib.sha256(b"result-blob").hexdigest(),
"user_id": "u1",
},
}
]
}
mock_pc = MagicMock()
mock_pc.return_value.Index.return_value = mock_index
return mock_pc, mock_index
# ── TestEncryption ────────────────────────────────────────────────────
class TestEncryption:
def test_verify_checksum_correct(self) -> None:
assert verify_checksum(_BLOB, _CHECKSUM) is True
def test_verify_checksum_wrong(self) -> None:
assert verify_checksum(_BLOB, "0" * 64) is False
def test_verify_checksum_empty_checksum(self) -> None:
assert verify_checksum(_BLOB, "") is False
def test_verify_checksum_empty_blob(self) -> None:
expected = hashlib.sha256(b"").hexdigest()
assert verify_checksum(b"", expected) is True
def test_verify_checksum_tampered_blob(self) -> None:
tampered = _BLOB + b"\x00"
assert verify_checksum(tampered, _CHECKSUM) is False
def test_reject_if_tampered_passes_when_valid(self) -> None:
# Should not raise
reject_if_tampered(_BLOB, _CHECKSUM)
def test_reject_if_tampered_raises_400_on_mismatch(self) -> None:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
reject_if_tampered(_BLOB, "bad" * 20)
assert exc_info.value.status_code == 400
def test_reject_if_tampered_detail_mentions_checksum(self) -> None:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
reject_if_tampered(_BLOB, "bad" * 20)
assert "checksum" in exc_info.value.detail.lower()
def test_checksum_is_sha256_hex(self) -> None:
cs = hashlib.sha256(_BLOB).hexdigest()
assert len(cs) == 64
assert all(c in "0123456789abcdef" for c in cs)
# ── TestBlobStore ─────────────────────────────────────────────────────
class TestBlobStore:
@pytest.mark.asyncio
async def test_upload_returns_correct_key(self, s3_bucket: str) -> None:
store = BlobStore()
key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
assert key == "u1/tasks/r1"
@pytest.mark.asyncio
async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
# Verify by downloading — no exception means object exists
retrieved = await store.download("u1", "u1/tasks/r1")
assert retrieved == _BLOB
@pytest.mark.asyncio
async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "notes", "n1", b"note-data", hashlib.sha256(b"note-data").hexdigest())
result = await store.download("u1", "u1/notes/n1")
assert result == b"note-data"
@pytest.mark.asyncio
async def test_delete_removes_object(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.delete("u1", "u1/tasks/r1")
with pytest.raises(ClientError) as exc_info:
await store.download("u1", "u1/tasks/r1")
assert exc_info.value.response["Error"]["Code"] == "NoSuchKey"
@pytest.mark.asyncio
async def test_delete_is_idempotent(self, s3_bucket: str) -> None:
store = BlobStore()
# Delete a key that never existed — should not raise
await store.delete("u1", "u1/tasks/nonexistent")
@pytest.mark.asyncio
async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM)
keys = await store.list_keys("u1", "tasks")
assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"}
@pytest.mark.asyncio
async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM)
keys = await store.list_keys("u1", "tasks")
assert "u1/notes/n1" not in keys
assert "u1/tasks/r1" in keys
@pytest.mark.asyncio
async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM)
keys_u1 = await store.list_keys("u1", "tasks")
assert "u2/tasks/r1" not in keys_u1
@pytest.mark.asyncio
async def test_list_keys_empty_table(self, s3_bucket: str) -> None:
store = BlobStore()
keys = await store.list_keys("u1", "tasks")
assert keys == []
@pytest.mark.asyncio
async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
# Verify S3 metadata was set — check via head_object
with patch("app.storage.blob_store.settings") as mock_settings:
mock_settings.S3_BUCKET = _BUCKET
mock_settings.S3_REGION = _REGION
mock_settings.AWS_ACCESS_KEY_ID = "testing"
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
client = boto3.client("s3", region_name=_REGION)
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
assert response.get("ServerSideEncryption") == "AES256"
@pytest.mark.asyncio
async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
client = boto3.client("s3", region_name=_REGION)
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
assert response["Metadata"]["checksum"] == _CHECKSUM
# ── _blob_to_vector helper ────────────────────────────────────────────
class TestBlobToVector:
def test_returns_32_floats(self) -> None:
v = _blob_to_vector(b"test")
assert len(v) == 32
def test_all_values_in_range(self) -> None:
v = _blob_to_vector(b"test")
assert all(-1.0 <= x <= 1.0 for x in v)
def test_deterministic(self) -> None:
assert _blob_to_vector(b"same") == _blob_to_vector(b"same")
def test_different_blobs_different_vectors(self) -> None:
assert _blob_to_vector(b"aaa") != _blob_to_vector(b"bbb")
# ── TestVectorStorePinecone ───────────────────────────────────────────
class TestVectorStorePinecone:
def _store(self) -> VectorStore:
store = VectorStore()
store._use_pinecone = lambda: True # type: ignore[method-assign]
return store
@pytest.mark.asyncio
async def test_upsert_calls_index_upsert(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
items = [VectorItem(id="v1", blob=b"enc-blob", checksum=hashlib.sha256(b"enc-blob").hexdigest())]
await store.upsert("u1", items)
mock_index.upsert.assert_called_once()
call_kwargs = mock_index.upsert.call_args[1]
assert call_kwargs.get("namespace") == "u1"
@pytest.mark.asyncio
async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
items = [VectorItem(id="v1", blob=b"secret", checksum=hashlib.sha256(b"secret").hexdigest())]
await store.upsert("u1", items)
vectors_arg = mock_index.upsert.call_args[1]["vectors"]
assert vectors_arg[0]["metadata"]["blob"] == base64.b64encode(b"secret").decode()
@pytest.mark.asyncio
async def test_search_calls_index_query(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
await store.search("u1", b"query-blob", top_k=5)
mock_index.query.assert_called_once()
query_kwargs = mock_index.query.call_args[1]
assert query_kwargs.get("namespace") == "u1"
assert query_kwargs.get("top_k") == 5
assert query_kwargs.get("include_metadata") is True
@pytest.mark.asyncio
async def test_search_returns_vector_search_results(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
results = await store.search("u1", b"query", top_k=10)
assert len(results) == 1
assert isinstance(results[0], VectorSearchResult)
assert results[0].id == "v1"
assert results[0].score == 0.95
assert results[0].blob == b"result-blob"
@pytest.mark.asyncio
async def test_search_uses_derived_query_vector(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
await store.search("u1", b"query-blob", top_k=3)
expected_vector = _blob_to_vector(b"query-blob")
actual_vector = mock_index.query.call_args[1].get("vector")
assert actual_vector == expected_vector
@pytest.mark.asyncio
async def test_delete_calls_index_delete(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
await store.delete("u1", ["v1", "v2"])
mock_index.delete.assert_called_once()
delete_kwargs = mock_index.delete.call_args[1]
assert delete_kwargs.get("namespace") == "u1"
assert set(delete_kwargs.get("ids", [])) == {"v1", "v2"}
# ── TestVectorStoreQdrant ─────────────────────────────────────────────
class TestVectorStoreQdrant:
def _store(self) -> VectorStore:
store = VectorStore()
store._use_pinecone = lambda: False # type: ignore[method-assign]
return store
def _qdrant_mock(self) -> MagicMock:
mock_hit = MagicMock()
mock_hit.id = "v1"
mock_hit.score = 0.88
mock_hit.payload = {
"blob": base64.b64encode(b"qdrant-result").decode(),
"user_id": "u1",
}
mock_client = MagicMock()
mock_client.search.return_value = [mock_hit]
return mock_client
@pytest.mark.asyncio
async def test_upsert_calls_client_upsert(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
await store.upsert("u1", items)
mock_client.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_upsert_uses_correct_collection(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
await store.upsert("u1", items)
call_kwargs = mock_client.upsert.call_args[1]
assert call_kwargs["collection_name"] == "adiuva_vectors"
@pytest.mark.asyncio
async def test_search_calls_client_search(self) -> None:
mock_client = self._qdrant_mock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.search("u1", b"query", top_k=5)
mock_client.search.assert_called_once()
@pytest.mark.asyncio
async def test_search_passes_limit(self) -> None:
mock_client = self._qdrant_mock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.search("u1", b"query", top_k=7)
call_kwargs = mock_client.search.call_args[1]
assert call_kwargs.get("limit") == 7
@pytest.mark.asyncio
async def test_search_returns_vector_search_results(self) -> None:
mock_client = self._qdrant_mock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
results = await store.search("u1", b"query", top_k=5)
assert len(results) == 1
assert isinstance(results[0], VectorSearchResult)
assert results[0].id == "v1"
assert results[0].score == 0.88
assert results[0].blob == b"qdrant-result"
@pytest.mark.asyncio
async def test_delete_calls_client_delete(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.delete("u1", ["v1", "v2"])
mock_client.delete.assert_called_once()
@pytest.mark.asyncio
async def test_delete_uses_correct_collection(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.delete("u1", ["v1"])
call_kwargs = mock_client.delete.call_args[1]
assert call_kwargs["collection_name"] == "adiuva_vectors"
# ── TestStorageRoutes (integration) ───────────────────────────────────
class TestStorageRoutes:
"""Integration tests for POST/GET/PUT/DELETE /api/v1/storage/records.
Pydantic v2 converts JSON string → bytes via ``str.encode('utf-8')``.
So "hello" in JSON becomes ``b"hello"`` on the server. We use plain
ASCII strings as blob values and compute checksums accordingly.
"""
_BLOB_STR = "encrypted-payload-opaque-to-server"
_BLOB_BYTES = _BLOB_STR.encode()
_BLOB_CHECKSUM = hashlib.sha256(_BLOB_BYTES).hexdigest()
@classmethod
def _create_payload(cls, blob_str: str | None = None) -> dict:
blob_str = blob_str or cls._BLOB_STR
checksum = hashlib.sha256(blob_str.encode()).hexdigest()
return {
"table": "tasks",
"blob": blob_str,
"checksum": checksum,
}
def _create_record(self, client, tier="power", blob_str=None):
payload = self._create_payload(blob_str)
return client.post(
"/api/v1/storage/records",
json=payload,
headers=auth_header(tier),
)
# ── Create ────────────────────────────────────────────────────────
def test_create_record(self, client, s3_bucket) -> None:
resp = self._create_record(client)
assert resp.status_code == 201
data = resp.json()
assert "id" in data
assert "created_at" in data
def test_create_record_bad_checksum(self, client, s3_bucket) -> None:
payload = {
"table": "tasks",
"blob": self._BLOB_STR,
"checksum": "0" * 64,
}
resp = client.post(
"/api/v1/storage/records",
json=payload,
headers=auth_header("power"),
)
assert resp.status_code == 400
def test_create_record_free_tier_blocked(self, client, s3_bucket) -> None:
"""Free tier has cloud_storage_gb=0 → 402."""
resp = self._create_record(client, tier="free")
assert resp.status_code == 402
def test_create_record_pro_tier_allowed(self, client, s3_bucket) -> None:
"""Pro tier has cloud_storage_gb=5 → succeeds for small blob."""
resp = self._create_record(client, tier="pro")
assert resp.status_code == 201
# ── List ──────────────────────────────────────────────────────────
def test_list_records(self, client, s3_bucket) -> None:
self._create_record(client)
self._create_record(client, blob_str="second-blob")
resp = client.get(
"/api/v1/storage/records",
headers=auth_header("power"),
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
# Each entry has metadata, no blob bytes
for item in data:
assert "id" in item
assert "table" in item
assert "checksum" in item
assert "blob" not in item
def test_list_records_filter_by_table(self, client, s3_bucket) -> None:
self._create_record(client)
# Create in a different table
note_blob = "note-blob"
payload = {
"table": "notes",
"blob": note_blob,
"checksum": hashlib.sha256(note_blob.encode()).hexdigest(),
}
client.post(
"/api/v1/storage/records",
json=payload,
headers=auth_header("power"),
)
resp = client.get(
"/api/v1/storage/records?table=notes",
headers=auth_header("power"),
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["table"] == "notes"
def test_list_records_isolated_per_user(self, client, s3_bucket) -> None:
"""One user's records should not appear in another user's list."""
self._create_record(client, tier="power")
resp = client.get(
"/api/v1/storage/records",
headers=auth_header("team"),
)
assert resp.json() == []
# ── Download ──────────────────────────────────────────────────────
def test_download_record(self, client, s3_bucket) -> None:
create_resp = self._create_record(client)
record_id = create_resp.json()["id"]
resp = client.get(
f"/api/v1/storage/records/{record_id}",
headers=auth_header("power"),
)
assert resp.status_code == 200
assert resp.content == self._BLOB_BYTES
assert resp.headers["X-Checksum"] == self._BLOB_CHECKSUM
def test_download_record_not_found(self, client, s3_bucket) -> None:
resp = client.get(
"/api/v1/storage/records/nonexistent-id",
headers=auth_header("power"),
)
assert resp.status_code == 404
# ── Update ────────────────────────────────────────────────────────
def test_update_record(self, client, s3_bucket) -> None:
create_resp = self._create_record(client)
record_id = create_resp.json()["id"]
new_blob_str = "updated-encrypted-payload"
new_checksum = hashlib.sha256(new_blob_str.encode()).hexdigest()
resp = client.put(
f"/api/v1/storage/records/{record_id}",
json={"blob": new_blob_str, "checksum": new_checksum},
headers=auth_header("power"),
)
assert resp.status_code == 200
assert resp.json() == {"ok": True}
# Verify download returns the updated blob
dl = client.get(
f"/api/v1/storage/records/{record_id}",
headers=auth_header("power"),
)
assert dl.content == new_blob_str.encode()
def test_update_record_bad_checksum(self, client, s3_bucket) -> None:
create_resp = self._create_record(client)
record_id = create_resp.json()["id"]
resp = client.put(
f"/api/v1/storage/records/{record_id}",
json={"blob": "some-data", "checksum": "0" * 64},
headers=auth_header("power"),
)
assert resp.status_code == 400
# ── Delete ────────────────────────────────────────────────────────
def test_delete_record(self, client, s3_bucket) -> None:
create_resp = self._create_record(client)
record_id = create_resp.json()["id"]
resp = client.delete(
f"/api/v1/storage/records/{record_id}",
headers=auth_header("power"),
)
assert resp.status_code == 200
assert resp.json() == {"ok": True}
# Subsequent GET should return 404
dl = client.get(
f"/api/v1/storage/records/{record_id}",
headers=auth_header("power"),
)
assert dl.status_code == 404
def test_delete_record_not_found(self, client, s3_bucket) -> None:
resp = client.delete(
"/api/v1/storage/records/nonexistent",
headers=auth_header("power"),
)
assert resp.status_code == 404

View File

@@ -45,13 +45,15 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
return frames return frames
async def _mock_home_stream(user_id, message, context): async def _mock_home_stream(user_id, message, context, db_session_factory=None):
yield "token", "Hello" yield "token", "Here are your tasks:\n<task>[t1,t2]</task>"
yield "mutations", []
async def _mock_floating_stream(user_id, message, context): async def _mock_floating_stream(user_id, message, context, scope=None, db_session_factory=None):
yield "floating_domain", {"type": "task", "id": None, "section": None} yield "tool_end", {"name": "task_agent", "result": "ok"}
yield "token", "Here is a summary" yield "token", "Here is a summary"
yield "mutations", []
# ── tests ───────────────────────────────────────────────────────────────────── # ── tests ─────────────────────────────────────────────────────────────────────
@@ -102,7 +104,7 @@ def test_floating_request_produces_domain_frame(client):
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end) assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain) domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
assert domain_frame["domain"]["type"] == "task" assert domain_frame["domain"] == "tasks"
assert domain_frame["request_id"] == "p1" assert domain_frame["request_id"] == "p1"
@@ -111,8 +113,9 @@ def test_home_request_request_id_propagated(client):
token = make_jwt("power", user_id=USER_ID) token = make_jwt("power", user_id=USER_ID)
req_id = "my-unique-req-id" req_id = "my-unique-req-id"
async def _stream(user_id, message, context): async def _stream(user_id, message, context, db_session_factory=None):
yield "token", "ok" yield "token", "ok"
yield "mutations", []
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream): with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: